diff --git a/.github/workflows/ado-net-tests.yml b/.github/workflows/ado-net-tests.yml new file mode 100644 index 00000000..11749d17 --- /dev/null +++ b/.github/workflows/ado-net-tests.yml @@ -0,0 +1,64 @@ +on: + push: + branches: [ main ] + pull_request: +permissions: + contents: read + pull-requests: write +name: ADO.NET Tests +jobs: + units: + strategy: + matrix: + dotnet-version: ['8.0.x', '9.0.x'] + os: [ubuntu-latest, macos-latest, windows-latest] + runs-on: ${{ matrix.os }} + steps: + - name: Checkout code + uses: actions/checkout@v4 + - name: Install dotnet + uses: actions/setup-dotnet@v4 + with: + dotnet-version: ${{ matrix.dotnet-version }} + - name: Install Go + uses: actions/setup-go@v5 + with: + go-version: '1.25' + # Install compilers for cross-compiling between operating systems. + - name: Install compilers + run: | + echo "$RUNNER_OS" + if [ "$RUNNER_OS" == "Windows" ]; then + echo "Windows does not yet support cross compiling" + elif [ "$RUNNER_OS" == "macOS" ]; then + brew tap SergioBenitez/osxct + brew install x86_64-unknown-linux-gnu + brew install mingw-w64 + else + sudo apt-get update + sudo apt install -y g++-mingw-w64-x86-64 gcc-mingw-w64-x86-64 + sudo apt-get install -y gcc-arm-linux-gnueabihf + fi + shell: bash + - name: Build the .NET wrapper + working-directory: spannerlib/wrappers/spannerlib-dotnet + run: | + echo "$RUNNER_OS" + ./build.sh + shell: bash + - name: Restore .NET wrapper dependencies + run: dotnet restore + working-directory: spannerlib/wrappers/spannerlib-dotnet + shell: bash + - name: Build .NET wrapper + run: dotnet build --no-restore -c Release + working-directory: spannerlib/wrappers/spannerlib-dotnet + shell: bash + - name: spanner-ado-net-tests + working-directory: drivers/spanner-ado-net/spanner-ado-net-tests + run: dotnet test --verbosity normal + shell: bash + - name: spanner-ado-net-specification-tests + working-directory: drivers/spanner-ado-net/spanner-ado-net-specification-tests + run: dotnet test --verbosity normal + shell: bash diff --git a/drivers/spanner-ado-net/.gitignore b/drivers/spanner-ado-net/.gitignore new file mode 100644 index 00000000..808112cd --- /dev/null +++ b/drivers/spanner-ado-net/.gitignore @@ -0,0 +1,4 @@ +.idea +obj +bin +*DotSettings.user diff --git a/drivers/spanner-ado-net/LICENSE b/drivers/spanner-ado-net/LICENSE new file mode 100644 index 00000000..d6456956 --- /dev/null +++ b/drivers/spanner-ado-net/LICENSE @@ -0,0 +1,202 @@ + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/drivers/spanner-ado-net/README.md b/drivers/spanner-ado-net/README.md new file mode 100644 index 00000000..fbcfda13 --- /dev/null +++ b/drivers/spanner-ado-net/README.md @@ -0,0 +1,5 @@ +# Spanner ADO.NET Data Provider + +ADO.NET Data Provider for Spanner. + +__ALPHA: Not for production use__ diff --git a/drivers/spanner-ado-net/global.json b/drivers/spanner-ado-net/global.json new file mode 100644 index 00000000..2ddda36c --- /dev/null +++ b/drivers/spanner-ado-net/global.json @@ -0,0 +1,7 @@ +{ + "sdk": { + "version": "8.0.0", + "rollForward": "latestMinor", + "allowPrerelease": false + } +} \ No newline at end of file diff --git a/drivers/spanner-ado-net/spanner-ado-net-benchmarks/README.md b/drivers/spanner-ado-net/spanner-ado-net-benchmarks/README.md new file mode 100644 index 00000000..e7015b23 --- /dev/null +++ b/drivers/spanner-ado-net/spanner-ado-net-benchmarks/README.md @@ -0,0 +1,5 @@ +# Spanner ADO.NET Data Provider Benchmarks + +Benchmarks for the ADO.NET Data Provider for Spanner. + +__ALPHA: Not for production use__ diff --git a/drivers/spanner-ado-net/spanner-ado-net-benchmarks/deploy.txt b/drivers/spanner-ado-net/spanner-ado-net-benchmarks/deploy.txt new file mode 100644 index 00000000..b08e5498 --- /dev/null +++ b/drivers/spanner-ado-net/spanner-ado-net-benchmarks/deploy.txt @@ -0,0 +1,46 @@ + +gcloud run deploy spannerlib-dotnet-benchmark-tpcc \ + --region=europe-north1 \ + --no-allow-unauthenticated --no-cpu-throttling \ + --min-instances=1 --max-instances=1 \ + --cpu=4 --memory=2Gi \ + --set-env-vars=NUM_WAREHOUSES=100,TRANSACTIONS_PER_SECOND=50,NUM_CLIENTS=50 \ + --base-image dotnet8 \ + --source . + +gcloud run deploy spannerlib-dotnet-benchmark-tpcc \ + --region=europe-north1 \ + --no-allow-unauthenticated --no-cpu-throttling \ + --min-instances=1 --max-instances=1 \ + --cpu=4 --memory=2Gi \ + --set-env-vars=NUM_WAREHOUSES=100,TRANSACTIONS_PER_SECOND=50,NUM_CLIENTS=50,RETRY_ABORTS_INTERNALLY=false \ + --base-image dotnet8 \ + --source . + + +gcloud run deploy spannerlib-dotnet-benchmark-tpcc \ + --region=europe-north1 \ + --no-allow-unauthenticated --no-cpu-throttling \ + --min-instances=1 --max-instances=1 \ + --cpu=4 --memory=2Gi \ + --set-env-vars=NUM_WAREHOUSES=100,CLIENT_TYPE=ClientLib,TRANSACTIONS_PER_SECOND=50,NUM_CLIENTS=50 \ + --base-image dotnet8 \ + --source . + +gcloud run deploy spannerlib-dotnet-benchmark-tpcc \ + --region=europe-north1 \ + --no-allow-unauthenticated --no-cpu-throttling \ + --min-instances=1 --max-instances=1 \ + --cpu=4 --memory=2Gi \ + --set-env-vars=NUM_WAREHOUSES=100,CLIENT_TYPE=NativeSpannerLib,TRANSACTIONS_PER_SECOND=50,NUM_CLIENTS=50 \ + --base-image dotnet8 \ + --source . + +gcloud run deploy spannerlib-dotnet-benchmark-tpcc \ + --region=europe-north1 \ + --no-allow-unauthenticated --no-cpu-throttling \ + --min-instances=1 --max-instances=1 \ + --cpu=4 --memory=2Gi \ + --set-env-vars=NUM_WAREHOUSES=100,CLIENT_TYPE=NativeSpannerLib,TRANSACTIONS_PER_SECOND=50,NUM_CLIENTS=50,RETRY_ABORTS_INTERNALLY=false \ + --base-image dotnet8 \ + --source . diff --git a/drivers/spanner-ado-net/spanner-ado-net-benchmarks/spanner-ado-net-benchmarks.csproj b/drivers/spanner-ado-net/spanner-ado-net-benchmarks/spanner-ado-net-benchmarks.csproj new file mode 100644 index 00000000..1c8e0a3d --- /dev/null +++ b/drivers/spanner-ado-net/spanner-ado-net-benchmarks/spanner-ado-net-benchmarks.csproj @@ -0,0 +1,19 @@ + + + + Exe + net8.0 + Google.Cloud.Spanner.DataProvider.Benchmarks + enable + enable + Google.Cloud.Spanner.DataProvider.Benchmarks + default + + + + + + + + + diff --git a/drivers/spanner-ado-net/spanner-ado-net-benchmarks/tpcc/LastNameGenerator.cs b/drivers/spanner-ado-net/spanner-ado-net-benchmarks/tpcc/LastNameGenerator.cs new file mode 100644 index 00000000..2179eb61 --- /dev/null +++ b/drivers/spanner-ado-net/spanner-ado-net-benchmarks/tpcc/LastNameGenerator.cs @@ -0,0 +1,22 @@ +namespace Google.Cloud.Spanner.DataProvider.Benchmarks.tpcc; + +public static class LastNameGenerator +{ + private static readonly string[] Parts = ["BAR", "OUGHT", "ABLE", "PRI", "PRES", "ESE", "ANTI", "CALLY", "ATION", "EING"]; + + public static string GenerateLastName(long rowIndex) { + int row; + if (rowIndex < 1000L) + { + row = (int) rowIndex; + } + else + { + row = Random.Shared.Next(1000); + } + return Parts[row / 100] + + Parts[row / 10 % 10] + + Parts[row % 10]; + } + +} \ No newline at end of file diff --git a/drivers/spanner-ado-net/spanner-ado-net-benchmarks/tpcc/Program.cs b/drivers/spanner-ado-net/spanner-ado-net-benchmarks/tpcc/Program.cs new file mode 100644 index 00000000..62e0d7d1 --- /dev/null +++ b/drivers/spanner-ado-net/spanner-ado-net-benchmarks/tpcc/Program.cs @@ -0,0 +1,158 @@ +using System.Collections.Concurrent; +using System.Data.Common; +using System.Diagnostics; +using Google.Cloud.Spanner.Admin.Database.V1; +using Google.Cloud.Spanner.DataProvider.Benchmarks.tpcc.loader; +using Microsoft.AspNetCore.Builder; + +namespace Google.Cloud.Spanner.DataProvider.Benchmarks.tpcc; + +public static class Program +{ + enum ClientType + { + SpannerLib, + NativeSpannerLib, + ClientLib, + } + + public static async Task Main(string[] args) + { + var cancellationTokenSource = new CancellationTokenSource(); + var builder = WebApplication.CreateBuilder(args); + var port = Environment.GetEnvironmentVariable("PORT") ?? "8080"; + var url = $"http://0.0.0.0:{port}"; + var app = builder.Build(); + app.MapGet("/", () => { }); + var webapp = app.RunAsync(url); + + var logWaitTime = int.Parse(Environment.GetEnvironmentVariable("LOG_WAIT_TIME") ?? "10"); + var database = Environment.GetEnvironmentVariable("DATABASE") ?? "projects/appdev-soda-spanner-staging/instances/knut-test-ycsb/databases/dotnet-tpcc"; + var retryAbortsInternally = bool.Parse(Environment.GetEnvironmentVariable("RETRY_ABORTS_INTERNALLY") ?? "true"); + var numWarehouses = int.Parse(Environment.GetEnvironmentVariable("NUM_WAREHOUSES") ?? "10"); + var numClients = int.Parse(Environment.GetEnvironmentVariable("NUM_CLIENTS") ?? "10"); + var targetTps = int.Parse(Environment.GetEnvironmentVariable("TRANSACTIONS_PER_SECOND") ?? "0"); + var clientTypeName = Environment.GetEnvironmentVariable("CLIENT_TYPE") ?? "SpannerLib"; + if (!Enum.TryParse(clientTypeName, out ClientType clientType)) + { + throw new ArgumentException($"Unknown client type: {clientTypeName}"); + } + + var connectionString = $"Data Source={database}"; + if (!retryAbortsInternally) + { + connectionString += ";retryAbortsInternally=false"; + } + await using (var connection = new SpannerConnection()) + { + connection.ConnectionString = connectionString; + await connection.OpenAsync(cancellationTokenSource.Token); + + Console.WriteLine("Creating schema..."); + await SchemaUtil.CreateSchemaAsync(connection, DatabaseDialect.Postgresql, cancellationTokenSource.Token); + + Console.WriteLine("Loading data..."); + var loader = new DataLoader(connection, numWarehouses); + await loader.LoadAsync(cancellationTokenSource.Token); + } + + Console.WriteLine("Running benchmark..."); + var stats = new Stats(); + + if (targetTps > 0) + { + var maxWaitTime = 2 * 1000 / targetTps; + Console.WriteLine($"Clients: {numClients}"); + Console.WriteLine($"Transactions per second: {targetTps}"); + Console.WriteLine($"Max wait time: {maxWaitTime}"); + var runners = new BlockingCollection(); + for (var client = 0; client < numClients; client++) + { + runners.Add(await CreateRunnerAsync(clientType, connectionString, stats, numWarehouses, cancellationTokenSource), cancellationTokenSource.Token); + } + var lastLogTime = DateTime.UtcNow; + while (!cancellationTokenSource.IsCancellationRequested) + { + var randomWaitTime = Random.Shared.Next(0, maxWaitTime); + var stopwatch = Stopwatch.StartNew(); + if (runners.TryTake(out var runner, 20_000, cancellationTokenSource.Token)) + { + var source = new CancellationTokenSource(); + source.CancelAfter(TimeSpan.FromSeconds(10)); + var token = source.Token; + stats.RegisterTransactionStarted(); + var task = runner!.RunTransactionAsync(token); + _ = task.ContinueWith(_ => + { + stats.RegisterTransactionCompleted(); + runners.Add(runner, cancellationTokenSource.Token); + task.Dispose(); + }, TaskContinuationOptions.ExecuteSynchronously); + } + else + { + await Console.Error.WriteLineAsync("No runner available"); + } + randomWaitTime -= (int) stopwatch.ElapsedMilliseconds; + if (randomWaitTime > 0) + { + await Task.Delay(TimeSpan.FromMilliseconds(randomWaitTime), cancellationTokenSource.Token); + } + if ((DateTime.UtcNow - lastLogTime).TotalSeconds >= logWaitTime) + { + Console.WriteLine($"Num available runners: {runners.Count}"); + Console.WriteLine($"Thread pool size: {ThreadPool.ThreadCount}"); + stats.LogStats(); + lastLogTime = DateTime.UtcNow; + } + } + } + else + { + var tasks = new List(); + for (var client = 0; client < numClients; client++) + { + var runner = await CreateRunnerAsync(clientType, connectionString, stats, numWarehouses, cancellationTokenSource); + tasks.Add(runner.RunAsync(cancellationTokenSource.Token)); + } + while (!cancellationTokenSource.Token.IsCancellationRequested) + { + await Task.Delay(TimeSpan.FromSeconds(logWaitTime), cancellationTokenSource.Token); + stats.LogStats(); + } + await Task.WhenAll(tasks); + } + + await app.StopAsync(cancellationTokenSource.Token); + await webapp; + } + + private static async Task CreateRunnerAsync( + ClientType clientType, + string connectionString, + Stats stats, + int numWarehouses, + CancellationTokenSource cancellationTokenSource) + { + DbConnection connection; + if (clientType == ClientType.SpannerLib) + { + connection = new SpannerConnection(); + } + else if (clientType == ClientType.NativeSpannerLib) + { + connection = new SpannerConnection {UseNativeLibrary = true}; + } + else if (clientType == ClientType.ClientLib) + { + connection = new Google.Cloud.Spanner.Data.SpannerConnection(); + } + else + { + throw new ArgumentException($"Unknown client type: {clientType}"); + } + connection.ConnectionString = connectionString; + await connection.OpenAsync(cancellationTokenSource.Token); + return new TpccRunner(stats, connection, numWarehouses); + } +} \ No newline at end of file diff --git a/drivers/spanner-ado-net/spanner-ado-net-benchmarks/tpcc/RowNotFoundException.cs b/drivers/spanner-ado-net/spanner-ado-net-benchmarks/tpcc/RowNotFoundException.cs new file mode 100644 index 00000000..f1b3530b --- /dev/null +++ b/drivers/spanner-ado-net/spanner-ado-net-benchmarks/tpcc/RowNotFoundException.cs @@ -0,0 +1,5 @@ +using System.Data.Common; + +namespace Google.Cloud.Spanner.DataProvider.Benchmarks.tpcc; + +public class RowNotFoundException(string message) : DbException(message); \ No newline at end of file diff --git a/drivers/spanner-ado-net/spanner-ado-net-benchmarks/tpcc/SchemaDefinition.cs b/drivers/spanner-ado-net/spanner-ado-net-benchmarks/tpcc/SchemaDefinition.cs new file mode 100644 index 00000000..dff383ea --- /dev/null +++ b/drivers/spanner-ado-net/spanner-ado-net-benchmarks/tpcc/SchemaDefinition.cs @@ -0,0 +1,200 @@ +namespace Google.Cloud.Spanner.DataProvider.Benchmarks.tpcc; + +public static class SchemaDefinition +{ + public const string CreateTablesPostgreSql = @" +START BATCH DDL; + +CREATE TABLE IF NOT EXISTS warehouse ( + w_id int not null, + w_name varchar(10), + w_street_1 varchar(20), + w_street_2 varchar(20), + w_city varchar(20), + w_state varchar(2), + w_zip varchar(9), + w_tax decimal, + w_ytd decimal, + primary key (w_id) +); + +create table IF NOT EXISTS district ( + d_id int not null, + w_id int not null, + d_name varchar(10), + d_street_1 varchar(20), + d_street_2 varchar(20), + d_city varchar(20), + d_state varchar(2), + d_zip varchar(9), + d_tax decimal, + d_ytd decimal, + d_next_o_id int, + primary key (w_id, d_id) +); + +-- CUSTOMER TABLE + +create table IF NOT EXISTS customer ( + c_id int not null, + d_id int not null, + w_id int not null, + c_first varchar(16), + c_middle varchar(2), + c_last varchar(16), + c_street_1 varchar(20), + c_street_2 varchar(20), + c_city varchar(20), + c_state varchar(2), + c_zip varchar(9), + c_phone varchar(16), + c_since timestamptz, + c_credit varchar(2), + c_credit_lim bigint, + c_discount decimal, + c_balance decimal, + c_ytd_payment decimal, + c_payment_cnt int, + c_delivery_cnt int, + c_data text, + PRIMARY KEY(w_id, d_id, c_id) +); + +-- HISTORY TABLE + +create table IF NOT EXISTS history ( + c_id int, + d_id int, + w_id int, + h_d_id int, + h_w_id int, + h_date timestamptz, + h_amount decimal, + h_data varchar(24), + PRIMARY KEY(c_id, d_id, w_id, h_d_id, h_w_id, h_date) +); + +create table IF NOT EXISTS orders ( + o_id int not null, + d_id int not null, + w_id int not null, + c_id int not null, + o_entry_d timestamptz, + o_carrier_id int, + o_ol_cnt int, + o_all_local int, + PRIMARY KEY(w_id, d_id, c_id, o_id) +); + +-- NEW_ORDER table + +create table IF NOT EXISTS new_orders ( + o_id int not null, + c_id int not null, + d_id int not null, + w_id int not null, + PRIMARY KEY(w_id, d_id, o_id, c_id) +); + +create table IF NOT EXISTS order_line ( + o_id int not null, + c_id int not null, + d_id int not null, + w_id int not null, + ol_number int not null, + ol_i_id int, + ol_supply_w_id int, + ol_delivery_d timestamptz, + ol_quantity int, + ol_amount decimal, + ol_dist_info varchar(24), + PRIMARY KEY(w_id, d_id, o_id, c_id, ol_number) +); + +-- STOCK table + +create table IF NOT EXISTS stock ( + s_i_id int not null, + w_id int not null, + s_quantity int, + s_dist_01 varchar(24), + s_dist_02 varchar(24), + s_dist_03 varchar(24), + s_dist_04 varchar(24), + s_dist_05 varchar(24), + s_dist_06 varchar(24), + s_dist_07 varchar(24), + s_dist_08 varchar(24), + s_dist_09 varchar(24), + s_dist_10 varchar(24), + s_ytd decimal, + s_order_cnt int, + s_remote_cnt int, + s_data varchar(50), + PRIMARY KEY(w_id, s_i_id) +); + +create table IF NOT EXISTS item ( + i_id int not null, + i_im_id int, + i_name varchar(24), + i_price decimal, + i_data varchar(50), + PRIMARY KEY(i_id) +); + +CREATE INDEX idx_customer ON customer (w_id,d_id,c_last,c_first); +CREATE INDEX idx_orders ON orders (w_id,d_id,o_id); +CREATE INDEX fkey_stock_2 ON stock (s_i_id); +CREATE INDEX fkey_order_line_2 ON order_line (ol_supply_w_id,ol_i_id); +CREATE INDEX fkey_history_1 ON history (w_id,d_id,c_id); +CREATE INDEX fkey_history_2 ON history (h_w_id,h_d_id ); + +ALTER TABLE new_orders ADD CONSTRAINT fkey_new_orders_1_ FOREIGN KEY(w_id,d_id,c_id,o_id) REFERENCES orders(w_id,d_id,c_id,o_id); +ALTER TABLE orders ADD CONSTRAINT fkey_orders_1_ FOREIGN KEY(w_id,d_id,c_id) REFERENCES customer(w_id,d_id,c_id); +ALTER TABLE customer ADD CONSTRAINT fkey_customer_1_ FOREIGN KEY(w_id,d_id) REFERENCES district(w_id,d_id); +ALTER TABLE history ADD CONSTRAINT fkey_history_1_ FOREIGN KEY(w_id,d_id,c_id) REFERENCES customer(w_id,d_id,c_id); +ALTER TABLE history ADD CONSTRAINT fkey_history_2_ FOREIGN KEY(h_w_id,h_d_id) REFERENCES district(w_id,d_id); +ALTER TABLE district ADD CONSTRAINT fkey_district_1_ FOREIGN KEY(w_id) REFERENCES warehouse(w_id); +ALTER TABLE order_line ADD CONSTRAINT fkey_order_line_1_ FOREIGN KEY(w_id,d_id,c_id,o_id) REFERENCES orders(w_id,d_id,c_id,o_id); +ALTER TABLE order_line ADD CONSTRAINT fkey_order_line_2_ FOREIGN KEY(ol_supply_w_id,ol_i_id) REFERENCES stock(w_id,s_i_id); +ALTER TABLE stock ADD CONSTRAINT fkey_stock_1_ FOREIGN KEY(w_id) REFERENCES warehouse(w_id); +ALTER TABLE stock ADD CONSTRAINT fkey_stock_2_ FOREIGN KEY(s_i_id) REFERENCES item(i_id); + +RUN BATCH; +"; + + public const string DropTables = @" +start batch ddl; + +drop index if exists fkey_history_2; +drop index if exists fkey_history_1; +drop index if exists fkey_order_line_2; +drop index if exists fkey_stock_2; +drop index if exists idx_orders; +drop index if exists idx_customer; + +ALTER TABLE new_orders DROP CONSTRAINT fkey_new_orders_1_; +ALTER TABLE orders DROP CONSTRAINT fkey_orders_1_; +ALTER TABLE customer DROP CONSTRAINT fkey_customer_1_; +ALTER TABLE history DROP CONSTRAINT fkey_history_1_; +ALTER TABLE history DROP CONSTRAINT fkey_history_2_; +ALTER TABLE district DROP CONSTRAINT fkey_district_1_; +ALTER TABLE order_line DROP CONSTRAINT fkey_order_line_1_; +ALTER TABLE order_line DROP CONSTRAINT fkey_order_line_2_; +ALTER TABLE stock DROP CONSTRAINT fkey_stock_1_; +ALTER TABLE stock DROP CONSTRAINT fkey_stock_2_; + +drop table if exists new_orders; +drop table if exists order_line; +drop table if exists history; +drop table if exists orders; +drop table if exists stock; +drop table if exists customer; +drop table if exists district; +drop table if exists warehouse; +drop table if exists item; + +run batch; +"; +} \ No newline at end of file diff --git a/drivers/spanner-ado-net/spanner-ado-net-benchmarks/tpcc/SchemaUtil.cs b/drivers/spanner-ado-net/spanner-ado-net-benchmarks/tpcc/SchemaUtil.cs new file mode 100644 index 00000000..49acdd8e --- /dev/null +++ b/drivers/spanner-ado-net/spanner-ado-net-benchmarks/tpcc/SchemaUtil.cs @@ -0,0 +1,47 @@ +using Google.Cloud.Spanner.Admin.Database.V1; + +namespace Google.Cloud.Spanner.DataProvider.Benchmarks.tpcc; + +static class SchemaUtil +{ + internal static async Task CreateSchemaAsync(SpannerConnection connection, DatabaseDialect dialect, CancellationToken cancellationToken) + { + await using var cmd = connection.CreateCommand(); + cmd.CommandText = "select count(1) " + + "from information_schema.tables " + + "where " + + (dialect == DatabaseDialect.Postgresql ? "table_schema='public' and " : "table_schema='' and ") + + "table_name in ('warehouse', 'district', 'customer', 'history', 'orders', 'new_orders', 'order_line', 'stock', 'item')"; + var count = await cmd.ExecuteScalarAsync(cancellationToken); + if (count is long and 9) + { + return; + } + + var commands = SchemaDefinition.CreateTablesPostgreSql.Split(";"); + foreach (var command in commands) + { + if (command.Trim() == "") + { + continue; + } + cmd.CommandText = command; + await cmd.ExecuteNonQueryAsync(cancellationToken); + } + } + + internal static async Task DropSchemaAsync(SpannerConnection connection, CancellationToken cancellationToken) + { + await using var cmd = connection.CreateCommand(); + var commands = SchemaDefinition.DropTables.Split(";"); + foreach (var command in commands) + { + if (command.Trim() == "") + { + continue; + } + cmd.CommandText = command; + await cmd.ExecuteNonQueryAsync(cancellationToken); + } + } +} \ No newline at end of file diff --git a/drivers/spanner-ado-net/spanner-ado-net-benchmarks/tpcc/Stats.cs b/drivers/spanner-ado-net/spanner-ado-net-benchmarks/tpcc/Stats.cs new file mode 100644 index 00000000..56ab5e10 --- /dev/null +++ b/drivers/spanner-ado-net/spanner-ado-net-benchmarks/tpcc/Stats.cs @@ -0,0 +1,97 @@ +namespace Google.Cloud.Spanner.DataProvider.Benchmarks.tpcc; + +internal class Stats +{ + private readonly DateTime _startTime; + private ulong _numTransactions; + private ulong _numTransactionsStarted; + private ulong _numTransactionsCompleted; + private ulong _numFailedTransactions; + private ulong _numNewOrderTransactions; + private ulong _numPaymentTransactions; + private ulong _numOrderStatusTransactions; + private ulong _numDeliveryTransactions; + private ulong _numStockLevelTransactions; + private Exception? _lastException; + + private ulong _totalMillis; + + internal Stats() + { + _startTime = DateTime.UtcNow; + } + + internal void RegisterTransactionStarted() + { + Interlocked.Increment(ref _numTransactionsStarted); + } + + internal void RegisterTransactionCompleted() + { + Interlocked.Increment(ref _numTransactionsCompleted); + } + + internal void RegisterTransaction(TpccRunner.TransactionType transactionType, TimeSpan duration) + { + Interlocked.Increment(ref _numTransactions); + Interlocked.Add(ref _totalMillis, (ulong) duration.TotalMilliseconds); + switch (transactionType) + { + case TpccRunner.TransactionType.NewOrder: + Interlocked.Increment(ref _numNewOrderTransactions); + break; + case TpccRunner.TransactionType.Payment: + Interlocked.Increment(ref _numPaymentTransactions); + break; + case TpccRunner.TransactionType.OrderStatus: + Interlocked.Increment(ref _numOrderStatusTransactions); + break; + case TpccRunner.TransactionType.Delivery: + Interlocked.Increment(ref _numDeliveryTransactions); + break; + case TpccRunner.TransactionType.StockLevel: + Interlocked.Increment(ref _numStockLevelTransactions); + break; + default: + throw new ArgumentOutOfRangeException(nameof(transactionType), transactionType, null); + } + } + + internal void RegisterFailedTransaction(TpccRunner.TransactionType transactionType, TimeSpan duration, Exception error) + { + Interlocked.Increment(ref _numFailedTransactions); + lock (this) + { + _lastException = error; + } + } + + internal void LogStats() + { + lock (this) + { + if (_lastException != null) + { + Console.Error.WriteLine(_lastException); + _lastException = null; + } + } + Console.Write(ToString()); + } + + public override string ToString() + { + return $" Total duration: {DateTime.UtcNow - _startTime}{Environment.NewLine}" + + $"Transactions/sec: {Interlocked.Read(ref _numTransactions) / (DateTime.UtcNow - _startTime).TotalSeconds}{Environment.NewLine}" + + $" Total: {Interlocked.Read(ref _numTransactions)}{Environment.NewLine}" + + $" Avg: {Interlocked.Read(ref _totalMillis) / Interlocked.Read(ref _numTransactions)}{Environment.NewLine}" + + $" Started: {Interlocked.Read(ref _numTransactionsStarted)}{Environment.NewLine}" + + $" Completed: {Interlocked.Read(ref _numTransactionsCompleted)}{Environment.NewLine}" + + $" Failed: {Interlocked.Read(ref _numFailedTransactions)}{Environment.NewLine}" + + $" Num new order: {Interlocked.Read(ref _numNewOrderTransactions)}{Environment.NewLine}" + + $" Num payment: {Interlocked.Read(ref _numPaymentTransactions)}{Environment.NewLine}" + + $"Num order status: {Interlocked.Read(ref _numOrderStatusTransactions)}{Environment.NewLine}" + + $" Num delivery: {Interlocked.Read(ref _numDeliveryTransactions)}{Environment.NewLine}" + + $" Num stock level: {Interlocked.Read(ref _numStockLevelTransactions)}{Environment.NewLine}"; + } +} \ No newline at end of file diff --git a/drivers/spanner-ado-net/spanner-ado-net-benchmarks/tpcc/TpccRunner.cs b/drivers/spanner-ado-net/spanner-ado-net-benchmarks/tpcc/TpccRunner.cs new file mode 100644 index 00000000..bf753df6 --- /dev/null +++ b/drivers/spanner-ado-net/spanner-ado-net-benchmarks/tpcc/TpccRunner.cs @@ -0,0 +1,861 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +using System.Data; +using System.Data.Common; +using System.Diagnostics; +using Google.Cloud.Spanner.Data; +using Google.Cloud.Spanner.DataProvider.Benchmarks.tpcc.loader; +using Google.Cloud.Spanner.V1; +using Google.Rpc; +using SpannerException = Google.Cloud.SpannerLib.SpannerException; + +namespace Google.Cloud.Spanner.DataProvider.Benchmarks.tpcc; + +internal class TpccRunner +{ + internal enum TransactionType + { + Unknown, + NewOrder, + Payment, + OrderStatus, + Delivery, + StockLevel, + } + + private readonly Stats _stats; + private readonly DbConnection _connection; + private readonly int _numWarehouses; + private readonly int _numDistrictsPerWarehouse; + private readonly int _numCustomersPerDistrict; + private readonly int _numItems; + private readonly bool _isClientLib; + + private DbTransaction? _currentTransaction; + + internal TpccRunner( + Stats stats, + DbConnection connection, + int numWarehouses, + int numDistrictsPerWarehouse = 10, + int numCustomersPerDistrict = 3000, + int numItems = 100_000) + { + _stats = stats; + _connection = connection; + _numWarehouses = numWarehouses; + _numDistrictsPerWarehouse = numDistrictsPerWarehouse; + _numCustomersPerDistrict = numCustomersPerDistrict; + _numItems = numItems; + _isClientLib = connection is Data.SpannerConnection; + } + + internal async Task RunAsync(CancellationToken cancellationToken) + { + while (!cancellationToken.IsCancellationRequested) + { + await RunTransactionAsync(cancellationToken); + } + } + + internal async Task RunTransactionAsync(CancellationToken cancellationToken) + { + var watch = Stopwatch.StartNew(); + var transaction = Random.Shared.Next(23); + var transactionType = TransactionType.Unknown; + var attempts = 0; + while (true) + { + attempts++; + try + { + if (transaction < 10) + { + transactionType = TransactionType.NewOrder; + await NewOrderAsync(cancellationToken); + } + else if (transaction < 20) + { + transactionType = TransactionType.Payment; + await PaymentAsync(cancellationToken); + } + else if (transaction < 21) + { + transactionType = TransactionType.OrderStatus; + await OrderStatusAsync(cancellationToken); + } + else if (transaction < 22) + { + transactionType = TransactionType.Delivery; + await DeliveryAsync(cancellationToken); + } + else if (transaction < 23) + { + transactionType = TransactionType.StockLevel; + await StockLevelAsync(cancellationToken); + } + else + { + throw new ArgumentException($"Invalid transaction type {transaction}"); + } + + _stats.RegisterTransaction(transactionType, watch.Elapsed); + break; + } + catch (Exception exception) + { + await SilentRollbackTransactionAsync(cancellationToken); + if (attempts < 10) + { + if (exception is SpannerException { Code: Code.Aborted }) + { + continue; + } + + if (exception is Data.SpannerException { ErrorCode: ErrorCode.Aborted }) + { + continue; + } + } + else + { + await Console.Error.WriteLineAsync($"Giving up after {attempts} attempts"); + } + + _stats.RegisterFailedTransaction(transactionType, watch.Elapsed, exception); + break; + } + finally + { + if (_currentTransaction != null) + { + await Console.Error.WriteLineAsync("Transaction still open!"); + await _currentTransaction.DisposeAsync(); + _currentTransaction = null; + } + } + } + } + + private async Task NewOrderAsync(CancellationToken cancellationToken) + { + var warehouseId = DataLoader.ReverseBitsUnsigned((ulong)Random.Shared.Next(_numWarehouses)); + var districtId = DataLoader.ReverseBitsUnsigned((ulong)Random.Shared.Next(_numDistrictsPerWarehouse)); + var customerId = DataLoader.ReverseBitsUnsigned((ulong)Random.Shared.Next(_numCustomersPerDistrict)); + + var orderLineCount = Random.Shared.Next(5, 16); + var itemIds = new long[orderLineCount]; + var supplyWarehouses = new long[orderLineCount]; + var quantities = new int[orderLineCount]; + var rollback = Random.Shared.Next(100); + var allLocal = 1; + + for (var line = 0; line < orderLineCount; line++) + { + if (rollback == 1 && line == orderLineCount - 1) + { + itemIds[line] = DataLoader.ReverseBitsUnsigned(long.MaxValue); + } + else + { + // TODO: Make sure that the chosen item IDs are unique. + itemIds[line] = DataLoader.ReverseBitsUnsigned((ulong)Random.Shared.Next(_numItems)); + } + + if (Random.Shared.Next(100) == 50) + { + supplyWarehouses[line] = GetOtherWarehouseId(warehouseId); + allLocal = 0; + } + else + { + supplyWarehouses[line] = warehouseId; + } + + quantities[line] = Random.Shared.Next(1, 10); + } + + await BeginTransactionAsync("new_order", cancellationToken); + + // TODO: These queries can run in parallel. + var row = await ExecuteRowAsync( + "SELECT c_discount, c_last, c_credit, w_tax " + + "FROM customer c, warehouse w " + + "WHERE w.w_id = $1 AND c.w_id = w.w_id AND c.d_id = $2 AND c.c_id = $3 " + + "FOR UPDATE", cancellationToken, + warehouseId, districtId, customerId); + var discount = ToDecimal(row[0]); + var last = (string)row[1]; + var credit = (string)row[2]; + var warehouseTax = ToDecimal(row[3]); + + row = await ExecuteRowAsync( + "SELECT d_next_o_id, d_tax " + + "FROM district " + + "WHERE w_id = $1 AND d_id = $2 FOR UPDATE", cancellationToken, + warehouseId, districtId); + var districtNextOrderId = row[0] is DBNull ? 0L : (long)row[0]; + var districtTax = ToDecimal(row[1]); + + object batch = _isClientLib ? (_currentTransaction as Data.SpannerTransaction)!.CreateBatchDmlCommand() : _connection.CreateBatch(); + CreateBatchCommand( + batch, + "UPDATE district SET d_next_o_id = $1 WHERE d_id = $2 AND w_id= $3", + districtNextOrderId + 1L, districtId, warehouseId); + CreateBatchCommand( + batch, + "INSERT INTO orders (o_id, d_id, w_id, c_id, o_entry_d, o_ol_cnt, o_all_local) " + + "VALUES ($1,$2,$3,$4,CURRENT_TIMESTAMP,$5,$6)", + districtNextOrderId, districtId, warehouseId, customerId, orderLineCount, allLocal); + CreateBatchCommand( + batch, + "INSERT INTO new_orders (o_id, c_id, d_id, w_id) VALUES ($1,$2,$3,$4)", + districtNextOrderId, customerId, districtId, warehouseId); + + for (var line = 0; line < orderLineCount; line++) + { + var orderLineSupplyWarehouseId = supplyWarehouses[line]; + var orderLineItemId = itemIds[line]; + var orderLineQuantity = quantities[line]; + try + { + row = await ExecuteRowAsync( + "SELECT i_price, i_name, i_data FROM item WHERE i_id = $1", + cancellationToken, + orderLineItemId); + } + catch (RowNotFoundException) + { + // TODO: Record deliberate rollback + await RollbackTransactionAsync(cancellationToken); + return; + } + + var itemPrice = ToDecimal(row[0]); + var itemName = (string)row[1]; + var itemData = (string)row[2]; + + row = await ExecuteRowAsync( + "SELECT s_quantity, s_data, s_dist_01, s_dist_02, s_dist_03, s_dist_04, s_dist_05, s_dist_06, s_dist_07, s_dist_08, s_dist_09, s_dist_10 " + + "FROM stock " + + "WHERE s_i_id = $1 AND w_id= $2 FOR UPDATE", + cancellationToken, + orderLineItemId, orderLineSupplyWarehouseId); + var stockQuantity = (long)row[0]; + var stockData = (string)row[1]; + var stockDistrict = new string[10]; + for (int i = 2; i < stockDistrict.Length + 2; i++) + { + stockDistrict[i - 2] = (string)row[i]; + } + + var orderLineDistrictInfo = + stockDistrict[(int)(DataLoader.ReverseBitsUnsigned((ulong)districtId) % stockDistrict.Length)]; + if (stockQuantity > orderLineQuantity) + { + stockQuantity = stockQuantity - orderLineQuantity; + } + else + { + stockQuantity = stockQuantity - orderLineQuantity + 91; + } + + CreateBatchCommand(batch, "UPDATE stock SET s_quantity=$1 WHERE s_i_id=$2 AND w_id=$3", + stockQuantity, orderLineItemId, orderLineSupplyWarehouseId); + + var totalTax = 1m + warehouseTax + districtTax; + var discountFactor = 1m - discount; + var orderLineAmount = orderLineQuantity * itemPrice * totalTax * discountFactor; + CreateBatchCommand(batch, + "INSERT INTO order_line (o_id, c_id, d_id, w_id, ol_number, ol_i_id, ol_supply_w_id, ol_quantity, ol_amount, ol_dist_info) " + + "VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)", + districtNextOrderId, + customerId, + districtId, + warehouseId, + line, + orderLineItemId, + orderLineSupplyWarehouseId, + orderLineQuantity, + orderLineAmount, + orderLineDistrictInfo); + } + + if (batch is Data.SpannerBatchCommand spannerBatchCommand) + { + await spannerBatchCommand.ExecuteNonQueryAsync(cancellationToken); + } + else if (batch is DbBatch dbBatch) + { + await dbBatch.ExecuteNonQueryAsync(cancellationToken); + } + else + { + throw new NotSupportedException("Batch type not supported"); + } + await CommitTransactionAsync(cancellationToken); + } + + private async Task PaymentAsync(CancellationToken cancellationToken) + { + var warehouseId = DataLoader.ReverseBitsUnsigned((ulong)Random.Shared.Next(_numWarehouses)); + var districtId = DataLoader.ReverseBitsUnsigned((ulong)Random.Shared.Next(_numDistrictsPerWarehouse)); + var customerId = DataLoader.ReverseBitsUnsigned((ulong)Random.Shared.Next(_numCustomersPerDistrict)); + var amount = Random.Shared.Next(1, 500000) / 100m; + + long customerWarehouseId; + long customerDistrictId; + var lastName = LastNameGenerator.GenerateLastName(long.MaxValue); + bool byName; + object[] row; + if (Random.Shared.Next(100) < 60) + { + byName = true; + } + else + { + byName = false; + } + if (Random.Shared.Next(100) < 85) + { + customerWarehouseId = warehouseId; + customerDistrictId = districtId; + } + else + { + customerWarehouseId = GetOtherWarehouseId(warehouseId); + customerDistrictId = DataLoader.ReverseBitsUnsigned((ulong)Random.Shared.Next(_numDistrictsPerWarehouse)); + } + await BeginTransactionAsync("payment", cancellationToken); + await ExecuteNonQueryAsync("UPDATE warehouse SET w_ytd = w_ytd + $1 WHERE w_id = $2", + cancellationToken, amount, warehouseId); + + row = await ExecuteRowAsync( + "SELECT w_street_1, w_street_2, w_city, w_state, w_zip, w_name " + + "FROM warehouse " + + "WHERE w_id = $1", + cancellationToken, warehouseId); + var warehouseStreet1 = (string) row[0]; + var warehouseStreet2 = (string) row[1]; + var warehouseCity = (string) row[2]; + var warehouseState = (string) row[3]; + var warehouseZip = (string) row[4]; + var warehouseName = (string) row[5]; + + await ExecuteNonQueryAsync( + "UPDATE district SET d_ytd = d_ytd + $1 WHERE w_id = $2 AND d_id= $3", + cancellationToken, amount, warehouseId, districtId); + + row = await ExecuteRowAsync( + "SELECT d_street_1, d_street_2, d_city, d_state, d_zip, d_name " + + "FROM district " + + "WHERE w_id = $1 AND d_id = $2", + cancellationToken, warehouseId, districtId); + var districtStreet1 = (string) row[0]; + var districtStreet2 = (string) row[1]; + var districtCity = (string) row[2]; + var districtState = (string) row[3]; + var districtZip = (string) row[4]; + var districtName = (string) row[5]; + + if (byName) + { + row = await ExecuteRowAsync( + "SELECT count(c_id) namecnt " + + "FROM customer " + + "WHERE w_id = $1 AND d_id= $2 AND c_last=$3", + cancellationToken, customerWarehouseId, customerDistrictId, lastName); + var nameCount = (int) (long) row[0]; + if (nameCount % 2 == 0) + { + nameCount++; + } + var resultSet = await ExecuteQueryAsync( + "SELECT c_id " + + "FROM customer " + + "WHERE w_id=$1 AND d_id=$2 AND c_last=$3 " + + "ORDER BY c_first", + cancellationToken, customerWarehouseId, customerDistrictId, lastName); + for (var counter = 0; counter < Math.Min(nameCount, resultSet.Count); counter++) + { + customerId = (long) resultSet[counter][0]; + } + } + row = await ExecuteRowAsync( + "SELECT c_first, c_middle, c_last, c_street_1, c_street_2, c_city, c_state, c_zip, c_phone, c_credit, c_credit_lim, c_discount, c_balance, c_ytd_payment, c_since " + + "FROM customer " + + "WHERE w_id=$1 AND d_id=$2 AND c_id=$3 FOR UPDATE", + cancellationToken, customerWarehouseId, customerDistrictId, customerId); + var firstName = (string) row[0]; + var middleName = (string) row[1]; + lastName = (string) row[2]; + var street1 = (string) row[3]; + var street2 = (string) row[4]; + var city = (string) row[5]; + var state = (string) row[6]; + var zip = (string) row[7]; + var phone = (string) row[8]; + var credit = (string) row[9]; + var creditLimit = (long) row[10]; + var discount = ToDecimal(row[11]); + var balance = ToDecimal(row[12]); + var ytdPayment = ToDecimal(row[13]); + var since = (DateTime) row[14]; + + // TODO: Use batching from here + balance = balance - amount; + ytdPayment = ytdPayment + amount; + if ("BC".Equals(credit)) + { + row = await ExecuteRowAsync( + "SELECT c_data FROM customer WHERE w_id=$1 AND d_id=$2 AND c_id=$3", + cancellationToken, customerWarehouseId, customerDistrictId, customerId); + var customerData = (string)row[0]; + var newCustomerData = + $"| {customerId} {customerDistrictId} {customerWarehouseId} {districtId} {warehouseId} {amount} {DateTime.Now} {customerData}"; + if (newCustomerData.Length > 500) + { + newCustomerData = newCustomerData.Substring(0, 500); + } + await ExecuteNonQueryAsync( + "UPDATE customer " + + "SET c_balance=$1, c_ytd_payment=$2, c_data=$3 " + + "WHERE w_id = $4 AND d_id=$5 AND c_id=$6", + cancellationToken, + balance, + ytdPayment, + newCustomerData, + customerWarehouseId, + customerDistrictId, + customerId + ); + } + else + { + await ExecuteNonQueryAsync( + "UPDATE customer " + + "SET c_balance=$1, c_ytd_payment=$2 " + + "WHERE w_id = $3 AND d_id=$4 AND c_id=$5", + cancellationToken, balance, ytdPayment, customerWarehouseId, customerDistrictId, customerId); + } + + var data = $"{warehouseName} {districtName}"; + if (data.Length > 24) + { + data = data.Substring(0, 24); + } + await ExecuteNonQueryAsync( + "INSERT INTO history (d_id, w_id, c_id, h_d_id, h_w_id, h_date, h_amount, h_data) " + + "VALUES ($1,$2,$3,$4,$5,CURRENT_TIMESTAMP,$6,$7)", + cancellationToken, + customerDistrictId, + customerWarehouseId, + customerId, + districtId, + warehouseId, + amount, + data); + + await CommitTransactionAsync(cancellationToken); + } + + private async Task OrderStatusAsync(CancellationToken cancellationToken) + { + var warehouseId = DataLoader.ReverseBitsUnsigned((ulong)Random.Shared.Next(_numWarehouses)); + var districtId = DataLoader.ReverseBitsUnsigned((ulong)Random.Shared.Next(_numDistrictsPerWarehouse)); + var customerId = DataLoader.ReverseBitsUnsigned((ulong)Random.Shared.Next(_numCustomersPerDistrict)); + + var lastName = LastNameGenerator.GenerateLastName(long.MaxValue); + object[] row; + var byName = Random.Shared.Next(100) < 60; + + decimal balance; + string first, middle, last; + + await BeginTransactionAsync("order_status", cancellationToken); + if (byName) + { + row = await ExecuteRowAsync( + "SELECT count(c_id) namecnt " + + "FROM customer " + + "WHERE w_id=$1 AND d_id=$2 AND c_last=$3", + cancellationToken, warehouseId, districtId, lastName); + int nameCount = (int) (long) row[0]; + if (nameCount % 2 == 0) + { + nameCount++; + } + var resultSet = await ExecuteQueryAsync( + "SELECT c_balance, c_first, c_middle, c_id " + + "FROM customer WHERE w_id = $1 AND d_id=$2 AND c_last=$3 " + + "ORDER BY c_first", + cancellationToken, warehouseId, districtId, lastName); + for (int counter = 0; counter < Math.Min(nameCount, resultSet.Count); counter++) + { + balance = ToDecimal(resultSet[counter][0]); + first = (string) resultSet[counter][1]; + middle = (string) resultSet[counter][2]; + customerId = (long) resultSet[counter][3]; + } + } + else + { + row = await ExecuteRowAsync( + "SELECT c_balance, c_first, c_middle, c_last " + + "FROM customer " + + "WHERE w_id = $1 AND d_id=$2 AND c_id=$3", + cancellationToken, warehouseId, districtId, customerId); + balance = ToDecimal(row[0]); + first = (string) row[1]; + middle = (string) row[2]; + last = (string) row[3]; + } + + var maybeRow = await ExecuteRowAsync(false, + "SELECT o_id, o_carrier_id, o_entry_d " + + "FROM orders " + + "WHERE w_id = $1 AND d_id = $2 AND c_id = $3 " + + "ORDER BY o_id DESC", + cancellationToken, warehouseId, districtId, customerId); + var orderId = maybeRow == null ? 0L : (long) maybeRow[0]; + + long item_id, supply_warehouse_id, quantity; + decimal amount; + DateTime? delivery_date; + var results = await ExecuteQueryAsync( + "SELECT ol_i_id, ol_supply_w_id, ol_quantity, ol_amount, ol_delivery_d " + + "FROM order_line " + + "WHERE w_id = $1 AND d_id = $2 AND o_id = $3", + cancellationToken, warehouseId, districtId, orderId); + for (var counter = 0; counter < results.Count; counter++) + { + item_id = (long) results[counter][0]; // item_id + supply_warehouse_id = (long) results[counter][1]; // supply_warehouse_id + quantity = (long) results[counter][2]; // quantity + amount = ToDecimal(results[counter][3]); // amount + delivery_date = results[counter][4] is DBNull ? null : (DateTime) results[counter][4]; // delivery_date + } + await CommitTransactionAsync(cancellationToken); + } + + private async Task DeliveryAsync(CancellationToken cancellationToken) + { + var warehouseId = DataLoader.ReverseBitsUnsigned((ulong)Random.Shared.Next(_numWarehouses)); + var carrierId = Random.Shared.Next(10); + + await BeginTransactionAsync("delivery", cancellationToken); + + for (var district = 0L; district < _numDistrictsPerWarehouse; district++) + { + var districtId = DataLoader.ReverseBitsUnsigned((ulong)district); + var row = await ExecuteRowAsync(false, + "SELECT o_id, c_id " + + "FROM new_orders " + + "WHERE d_id = $1 AND w_id = $2 " + + "ORDER BY o_id ASC " + + "LIMIT 1 FOR UPDATE", + cancellationToken, districtId, warehouseId); + if (row != null) + { + var newOrderId = (long)row[0]; + var customerId = (long)row[1]; + await ExecuteNonQueryAsync( + "DELETE " + + "FROM new_orders " + + "WHERE o_id = $1 AND c_id = $2 AND d_id = $3 AND w_id = $4", + cancellationToken, newOrderId, customerId, districtId, warehouseId); + row = await ExecuteRowAsync( + "SELECT c_id FROM orders WHERE o_id = $1 AND d_id = $2 AND w_id = $3", + cancellationToken, newOrderId, districtId, warehouseId); + await ExecuteNonQueryAsync( + "UPDATE orders " + + "SET o_carrier_id = $1 " + + "WHERE o_id = $2 AND c_id = $3 AND d_id = $4 AND w_id = $5", + cancellationToken, carrierId, newOrderId, customerId, districtId, warehouseId); + await ExecuteNonQueryAsync( + "UPDATE order_line " + + "SET ol_delivery_d = CURRENT_TIMESTAMP " + + "WHERE o_id = $1 AND c_id = $2 AND d_id = $3 AND w_id = $4", + cancellationToken, newOrderId, customerId, districtId, warehouseId); + row = await ExecuteRowAsync( + "SELECT SUM(ol_amount) sm " + + "FROM order_line " + + "WHERE o_id = $1 AND c_id = $2 AND d_id = $3 AND w_id = $4", + cancellationToken, newOrderId, customerId, districtId, warehouseId); + var sumOrderLineAmount = ToDecimal(row[0]); + await ExecuteNonQueryAsync( + "UPDATE customer " + + "SET c_balance = c_balance + $1, c_delivery_cnt = c_delivery_cnt + 1 " + + "WHERE c_id = $2 AND d_id = $3 AND w_id = $4", + cancellationToken, sumOrderLineAmount, customerId, districtId, warehouseId); + } + } + await CommitTransactionAsync(cancellationToken); + } + + private async Task StockLevelAsync(CancellationToken cancellationToken) + { + var warehouseId = DataLoader.ReverseBitsUnsigned((ulong)Random.Shared.Next(_numWarehouses)); + var districtId = DataLoader.ReverseBitsUnsigned((ulong)Random.Shared.Next(_numDistrictsPerWarehouse)); + var level = Random.Shared.Next(10, 21); + + await BeginTransactionAsync("stock_level", cancellationToken); + String stockLevelQueries = "case1"; + Object[] row; + + row = await ExecuteRowAsync( + "SELECT d_next_o_id FROM district WHERE d_id = $1 AND w_id= $2", + cancellationToken, districtId, warehouseId); + var nextOrderId = row[0] is DBNull ? 0L : (long) row[0]; + var resultSet = + await ExecuteQueryAsync( + "SELECT COUNT(DISTINCT (s_i_id)) " + + "FROM order_line ol, stock s " + + "WHERE ol.w_id = $1 " + + "AND ol.d_id = $2 " + + "AND ol.o_id < $3 " + + "AND ol.o_id >= $4 " + + "AND s.w_id= $5 " + + "AND s_i_id=ol_i_id " + + "AND s_quantity < $6", + cancellationToken, + warehouseId, districtId, nextOrderId, nextOrderId - 20, warehouseId, level); + for (var counter = 0; counter < resultSet.Count; counter++) { + var orderLineItemId = (long) resultSet[counter][0]; + row = await ExecuteRowAsync( + "SELECT count(1) FROM stock " + + "WHERE w_id = $1 AND s_i_id = $2 " + + "AND s_quantity < $3", + cancellationToken, warehouseId, orderLineItemId, level); + var stockCount = (long) row[0]; + } + + await CommitTransactionAsync(cancellationToken); + } + + private decimal ToDecimal(object value) + { + return _isClientLib ? ((PgNumeric) value).ToDecimal(LossOfPrecisionHandling.Truncate) : (decimal) value; + } + + private async Task BeginTransactionAsync(string tag, CancellationToken cancellationToken = default) + { + if (_connection is Data.SpannerConnection spannerConnection) + { + _currentTransaction = await spannerConnection.BeginTransactionAsync( + SpannerTransactionCreationOptions.ReadWrite.WithIsolationLevel(IsolationLevel.RepeatableRead), + new SpannerTransactionOptions + { + Tag = tag, + }, + cancellationToken); + } + else if (_connection is SpannerConnection connection) + { + _currentTransaction = await connection.BeginTransactionAsync(IsolationLevel.RepeatableRead, cancellationToken); + await ExecuteNonQueryAsync($"set local transaction_tag = '{tag}'", cancellationToken); + } + } + + private async Task CommitTransactionAsync(CancellationToken cancellationToken = default) + { + if (_currentTransaction != null) + { + await _currentTransaction.CommitAsync(cancellationToken); + await _currentTransaction.DisposeAsync(); + _currentTransaction = null; + } + } + + private async Task SilentRollbackTransactionAsync(CancellationToken cancellationToken = default) + { + try + { + if (_currentTransaction != null) + { + await RollbackTransactionAsync(cancellationToken); + } + else + { + await ExecuteNonQueryAsync("rollback", cancellationToken); + } + } + catch (Exception) + { + if (_currentTransaction != null) + { + await _currentTransaction.DisposeAsync(); + _currentTransaction = null; + } + } + } + + private async Task RollbackTransactionAsync(CancellationToken cancellationToken = default) + { + if (_currentTransaction != null) + { + await _currentTransaction.RollbackAsync(cancellationToken); + await _currentTransaction.DisposeAsync(); + _currentTransaction = null; + } + } + + private void CreateBatchCommand(object batch, string commandText, params object[] parameters) + { + if (batch is Data.SpannerBatchCommand command) + { + CreateBatchCommand(command, commandText, parameters); + } + else if (batch is DbBatch dbBatch) + { + CreateBatchCommand(dbBatch, commandText, parameters); + } + else + { + throw new ArgumentException("unknown batch type"); + } + } + + private void CreateBatchCommand(Data.SpannerBatchCommand batch, string commandText, params object[] parameters) + { + var paramCollection = new Data.SpannerParameterCollection(); + for (var i=0; i < parameters.Length; i++) + { + var value = parameters[i]; + if (value is decimal d) + { + value = PgNumeric.FromDecimal(d); + } + paramCollection.Add(new Data.SpannerParameter {ParameterName = $"p{i+1}", Value = value}); + } + batch.Add(commandText, paramCollection); + } + + private void CreateBatchCommand(DbBatch batch, string commandText, params object[] parameters) + { + var batchCommand = batch.CreateBatchCommand(); + batchCommand.CommandText = commandText; + for (var i = 0; i < parameters.Length; i++) + { + CreateParameter(batchCommand, $"p{i+1}", parameters[i]); + } + batch.BatchCommands.Add(batchCommand); + } + + private void CreateParameter(DbBatchCommand cmd, string parameterName, object parameterValue) + { + var parameter = cmd.CreateParameter(); + parameter.ParameterName = parameterName; + parameter.Value = parameterValue; + cmd.Parameters.Add(parameter); + } + + private async Task ExecuteNonQueryAsync(string commandText, + CancellationToken cancellationToken, params object[] parameters) + { + using var command = CreateCommand(commandText, parameters); + await command.ExecuteNonQueryAsync(cancellationToken); + } + + private Task ExecuteRowAsync(string commandText, + CancellationToken cancellationToken, params object[] parameters) + { + return ExecuteRowAsync(true, commandText, cancellationToken, parameters)!; + } + + private async Task ExecuteRowAsync(bool mustFindRow, string commandText, + CancellationToken cancellationToken, params object[] parameters) + { + using var command = CreateCommand(commandText, parameters); + using var reader = await command.ExecuteReaderAsync(cancellationToken); + if (!await reader.ReadAsync(cancellationToken)) + { + if (mustFindRow) + { + throw new RowNotFoundException("Row not found"); + } + return null; + } + var result = new object[reader.FieldCount]; + for (var i = 0; i < reader.FieldCount; i++) + { + result[i] = reader.GetValue(i); + } + return result; + } + + private async Task> ExecuteQueryAsync(string commandText, + CancellationToken cancellationToken, params object[] parameters) + { + using var command = CreateCommand(commandText, parameters); + using var reader = await command.ExecuteReaderAsync(cancellationToken); + var result = new List(); + while (await reader.ReadAsync(cancellationToken)) + { + var row = new object[reader.FieldCount]; + for (var i = 0; i < reader.FieldCount; i++) + { + row[i] = reader.GetValue(i); + } + result.Add(row); + } + return result; + } + + private DbCommand CreateCommand(string commandText, params object[] parameters) + { + var command = _connection.CreateCommand(); + command.CommandText = commandText; + command.Transaction = _currentTransaction; + for (var i = 0; i < parameters.Length; i++) + { + CreateParameter(command, $"p{i+1}", parameters[i]); + } + return command; + } + + private void CreateParameter(DbCommand cmd, string parameterName, object parameterValue) + { + var parameter = cmd.CreateParameter(); + parameter.ParameterName = parameterName; + if (_isClientLib) + { + var value = parameterValue; + if (value is decimal d) + { + value = PgNumeric.FromDecimal(d); + } + parameter.Value = value; + } + else + { + parameter.Value = parameterValue; + } + cmd.Parameters.Add(parameter); + } + + private long GetOtherWarehouseId(long currentId) { + if (_numWarehouses == 1) { + return currentId; + } + while (true) { + var otherId = DataLoader.ReverseBitsUnsigned((ulong)Random.Shared.Next(_numWarehouses)); + if (otherId != currentId) { + return otherId; + } + } + } + +} \ No newline at end of file diff --git a/drivers/spanner-ado-net/spanner-ado-net-benchmarks/tpcc/loader/CustomerLoader.cs b/drivers/spanner-ado-net/spanner-ado-net-benchmarks/tpcc/loader/CustomerLoader.cs new file mode 100644 index 00000000..d5f82835 --- /dev/null +++ b/drivers/spanner-ado-net/spanner-ado-net-benchmarks/tpcc/loader/CustomerLoader.cs @@ -0,0 +1,112 @@ +using System.Globalization; +using Google.Cloud.Spanner.V1; +using Google.Protobuf.WellKnownTypes; + +namespace Google.Cloud.Spanner.DataProvider.Benchmarks.tpcc.loader; + +internal class CustomerLoader +{ + private readonly SpannerConnection _connection; + + private readonly int _warehouseCount; + + private readonly int _districtsPerWarehouse; + + private readonly int _customersPerDistrict; + + internal CustomerLoader(SpannerConnection connection, int warehouseCount, int districtsPerWarehouse, int customersPerDistrict) + { + _connection = connection; + _warehouseCount = warehouseCount; + _districtsPerWarehouse = districtsPerWarehouse; + _customersPerDistrict = customersPerDistrict; + } + + internal async Task LoadAsync(CancellationToken cancellationToken = default) + { + var count = await CountAsync(cancellationToken); + if (count >= _warehouseCount * _districtsPerWarehouse * _customersPerDistrict) + { + return; + } + + for (var warehouse = 0; warehouse < _warehouseCount; warehouse++) + { + for (var district = 0; district < _districtsPerWarehouse; district++) + { + var group = new BatchWriteRequest.Types.MutationGroup + { + Mutations = { Capacity = 1 } + }; + group.Mutations.Add(CreateMutation(warehouse, district, _customersPerDistrict)); + await _connection.WriteMutationsAsync(group, cancellationToken); + } + } + } + + private async Task CountAsync(CancellationToken cancellationToken = default) + { + await using var command = _connection.CreateCommand(); + command.CommandText = "SELECT COUNT(1) FROM customer"; + var result = await command.ExecuteScalarAsync(cancellationToken); + return result == null ? 0L : (long) result; + } + + private Mutation CreateMutation(int warehouse, int district, int rows) + { + var mutation = new Mutation + { + InsertOrUpdate = new Mutation.Types.Write + { + Table = "customer", + Columns = { "c_id", "d_id", "w_id", "c_first", "c_middle", "c_last", "c_street_1", "c_street_2", + "c_city", "c_state", "c_zip", "c_phone", "c_since", "c_credit", "c_credit_lim", "c_discount", + "c_balance", "c_ytd_payment", "c_payment_cnt", "c_delivery_cnt", "c_data", + }, + Values = + { + Capacity = _customersPerDistrict, + } + } + }; + for (var i = 0; i < rows; i++) + { + mutation.InsertOrUpdate.Values.Add(CreateRandomCustomer(warehouse, district, i)); + } + return mutation; + } + + private ListValue CreateRandomCustomer(int warehouse, int district, int index) + { + var row = new ListValue + { + Values = + { + Capacity = 22 + } + }; + row.Values.Add(Value.ForString($"{DataLoader.ReverseBitsUnsigned((ulong) index)}")); + row.Values.Add(Value.ForString($"{DataLoader.ReverseBitsUnsigned((ulong) district)}")); + row.Values.Add(Value.ForString($"{DataLoader.ReverseBitsUnsigned((ulong) warehouse)}")); + row.Values.Add(Value.ForString(DataLoader.RandomString(16))); + row.Values.Add(Value.ForString(DataLoader.RandomString(2))); + row.Values.Add(Value.ForString(LastNameGenerator.GenerateLastName(index))); + row.Values.Add(Value.ForString(DataLoader.RandomString(20))); + row.Values.Add(Value.ForString(DataLoader.RandomString(20))); + row.Values.Add(Value.ForString(DataLoader.RandomString(20))); + row.Values.Add(Value.ForString(DataLoader.RandomString(2))); + row.Values.Add(Value.ForString(DataLoader.RandomString(9))); + row.Values.Add(Value.ForString(DataLoader.RandomString(16))); + row.Values.Add(Value.ForString(DataLoader.RandomTimestamp())); + row.Values.Add(Value.ForString(Random.Shared.Next(2) == 0 ? "GC" : "BC")); + row.Values.Add(Value.ForString(Random.Shared.Next(100, 5000).ToString(CultureInfo.InvariantCulture))); + row.Values.Add(Value.ForString(DataLoader.RandomDecimal(1, 40))); + row.Values.Add(Value.ForString("0.0")); + row.Values.Add(Value.ForString("0.0")); + row.Values.Add(Value.ForString("0")); + row.Values.Add(Value.ForString("0")); + row.Values.Add(Value.ForString(DataLoader.RandomString(500))); + + return row; + } +} \ No newline at end of file diff --git a/drivers/spanner-ado-net/spanner-ado-net-benchmarks/tpcc/loader/DataLoader.cs b/drivers/spanner-ado-net/spanner-ado-net-benchmarks/tpcc/loader/DataLoader.cs new file mode 100644 index 00000000..a965b3d6 --- /dev/null +++ b/drivers/spanner-ado-net/spanner-ado-net-benchmarks/tpcc/loader/DataLoader.cs @@ -0,0 +1,83 @@ +using System.Globalization; +using System.Xml; + +namespace Google.Cloud.Spanner.DataProvider.Benchmarks.tpcc.loader; + +public class DataLoader +{ + private readonly SpannerConnection _connection; + private readonly int _numWarehouses; + private readonly int _numDistrictsPerWarehouse; + private readonly int _numCustomersPerDistrict; + private readonly int _numItems; + + public DataLoader( + SpannerConnection connection, + int numWarehouses, + int numDistrictsPerWarehouse = 10, + int numCustomersPerDistrict = 3000, + int numItems = 100_000) + { + _connection = connection; + _numWarehouses = numWarehouses; + _numDistrictsPerWarehouse = numDistrictsPerWarehouse; + _numCustomersPerDistrict = numCustomersPerDistrict; + _numItems = numItems; + } + + public async Task LoadAsync(CancellationToken cancellationToken) + { + Console.WriteLine("Loading warehouses..."); + var warehouseLoader = new WarehouseLoader(_connection, _numWarehouses); + await warehouseLoader.LoadAsync(cancellationToken); + Console.WriteLine("Loading items..."); + var itemLoader = new ItemLoader(_connection, _numItems); + await itemLoader.LoadAsync(cancellationToken); + Console.WriteLine("Loading districts..."); + var districtLoader = new DistrictLoader(_connection, _numWarehouses, _numDistrictsPerWarehouse); + await districtLoader.LoadAsync(cancellationToken); + Console.WriteLine("Loading customers..."); + var customerLoader = new CustomerLoader(_connection, _numWarehouses, _numDistrictsPerWarehouse, _numCustomersPerDistrict); + await customerLoader.LoadAsync(cancellationToken); + Console.WriteLine("Loading stock..."); + var stockLoader = new StockLoader(_connection, _numWarehouses, _numItems); + await stockLoader.LoadAsync(cancellationToken); + } + + public static long ReverseBitsUnsigned(ulong n) + { + // Step 1: Swap adjacent bits + n = ((n >> 1) & 0x5555555555555555UL) | ((n & 0x5555555555555555UL) << 1); + // Step 2: Swap adjacent pairs of bits + n = ((n >> 2) & 0x3333333333333333UL) | ((n & 0x3333333333333333UL) << 2); + // Step 3: Swap adjacent nibbles (4 bits) + n = ((n >> 4) & 0x0F0F0F0F0F0F0F0FUL) | ((n & 0x0F0F0F0F0F0F0F0FUL) << 4); + // Step 4: Swap adjacent bytes + n = ((n >> 8) & 0x00FF00FF00FF00FFUL) | ((n & 0x00FF00FF00FF00FFUL) << 8); + // Step 5: Swap adjacent 2-byte words + n = ((n >> 16) & 0x0000FFFF0000FFFFUL) | ((n & 0x0000FFFF0000FFFFUL) << 16); + // Step 6: Swap the high and low 4-byte words (32 bits) + n = (n >> 32) | (n << 32); + return (long) n; + } + + internal static string RandomString(int length) + { + const string chars = "ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"; + return new string(Enumerable.Repeat(chars, length) + .Select(s => s[Random.Shared.Next(s.Length)]).ToArray()); + } + + internal static string RandomDecimal(int min, int max) + { + var d = (decimal) Random.Shared.Next(min, max) / 100; + return d.ToString("F", CultureInfo.InvariantCulture); + } + + internal static string RandomTimestamp() + { + var ts = DateTime.UtcNow.AddTicks(-Random.Shared.NextInt64(10 * 365 * TimeSpan.TicksPerDay)); + return XmlConvert.ToString(Convert.ToDateTime(ts, CultureInfo.InvariantCulture), + XmlDateTimeSerializationMode.Utc); + } +} diff --git a/drivers/spanner-ado-net/spanner-ado-net-benchmarks/tpcc/loader/DistrictLoader.cs b/drivers/spanner-ado-net/spanner-ado-net-benchmarks/tpcc/loader/DistrictLoader.cs new file mode 100644 index 00000000..5e339853 --- /dev/null +++ b/drivers/spanner-ado-net/spanner-ado-net-benchmarks/tpcc/loader/DistrictLoader.cs @@ -0,0 +1,89 @@ +using Google.Cloud.Spanner.V1; +using Google.Protobuf.WellKnownTypes; + +namespace Google.Cloud.Spanner.DataProvider.Benchmarks.tpcc.loader; + +internal class DistrictLoader +{ + private readonly SpannerConnection _connection; + + private readonly int _warehouseCount; + + private readonly int _districtsPerWarehouse; + + internal DistrictLoader(SpannerConnection connection, int warehouseCount, int districtsPerWarehouse) + { + _connection = connection; + _warehouseCount = warehouseCount; + _districtsPerWarehouse = districtsPerWarehouse; + } + + internal async Task LoadAsync(CancellationToken cancellationToken = default) + { + var count = await CountAsync(cancellationToken); + if (count >= _warehouseCount * _districtsPerWarehouse) + { + return; + } + for (var warehouse = 0; warehouse < _warehouseCount; warehouse++) + { + var group = new BatchWriteRequest.Types.MutationGroup + { + Mutations = { Capacity = 1 } + }; + group.Mutations.Add(CreateMutation(warehouse, _districtsPerWarehouse)); + await _connection.WriteMutationsAsync(group, cancellationToken); + } + } + + private async Task CountAsync(CancellationToken cancellationToken = default) + { + await using var command = _connection.CreateCommand(); + command.CommandText = "SELECT COUNT(1) FROM district"; + var result = await command.ExecuteScalarAsync(cancellationToken); + return result == null ? 0L : (long) result; + } + + private Mutation CreateMutation(int warehouse, int rows) + { + var mutation = new Mutation + { + InsertOrUpdate = new Mutation.Types.Write + { + Table = "district", + Columns = { "d_id", "w_id", "d_name", "d_street_1", "d_street_2", "d_city", "d_state", "d_zip", "d_tax", "d_ytd" }, + Values = + { + Capacity = _districtsPerWarehouse, + } + } + }; + for (var i = 0; i < rows; i++) + { + mutation.InsertOrUpdate.Values.Add(CreateRandomDistrict(warehouse, i)); + } + return mutation; + } + + private ListValue CreateRandomDistrict(int warehouse, int index) + { + var row = new ListValue + { + Values = + { + Capacity = 10 + } + }; + row.Values.Add(Value.ForString($"{DataLoader.ReverseBitsUnsigned((ulong) index)}")); + row.Values.Add(Value.ForString($"{DataLoader.ReverseBitsUnsigned((ulong) warehouse)}")); + row.Values.Add(Value.ForString($"W#{warehouse}D#{index}")); + row.Values.Add(Value.ForString(DataLoader.RandomString(20))); + row.Values.Add(Value.ForString(DataLoader.RandomString(20))); + row.Values.Add(Value.ForString(DataLoader.RandomString(20))); + row.Values.Add(Value.ForString(DataLoader.RandomString(2))); + row.Values.Add(Value.ForString(DataLoader.RandomString(9))); + row.Values.Add(Value.ForString(DataLoader.RandomDecimal(0, 21))); + row.Values.Add(Value.ForString("0.0")); + return row; + } +} \ No newline at end of file diff --git a/drivers/spanner-ado-net/spanner-ado-net-benchmarks/tpcc/loader/ItemLoader.cs b/drivers/spanner-ado-net/spanner-ado-net-benchmarks/tpcc/loader/ItemLoader.cs new file mode 100644 index 00000000..aed4c3ee --- /dev/null +++ b/drivers/spanner-ado-net/spanner-ado-net-benchmarks/tpcc/loader/ItemLoader.cs @@ -0,0 +1,90 @@ +using Google.Cloud.Spanner.V1; +using Google.Protobuf.WellKnownTypes; + +namespace Google.Cloud.Spanner.DataProvider.Benchmarks.tpcc.loader; + +internal class ItemLoader +{ + private static readonly int RowsPerGroup = 1000; + + private readonly SpannerConnection _connection; + + private readonly int _rowCount; + + internal ItemLoader(SpannerConnection connection, int rowCount) + { + _connection = connection; + _rowCount = rowCount; + } + + internal async Task LoadAsync(CancellationToken cancellationToken = default) + { + var count = await CountAsync(cancellationToken); + if (count >= _rowCount) + { + return; + } + + var batch = 0; + var remaining = _rowCount; + while (remaining > 0) + { + var group = new BatchWriteRequest.Types.MutationGroup + { + Mutations = { Capacity = 1 } + }; + var rows = Math.Min(RowsPerGroup, remaining); + group.Mutations.Add(CreateMutation(batch, rows)); + await _connection.WriteMutationsAsync(group, cancellationToken); + remaining -= rows; + batch++; + } + } + + private async Task CountAsync(CancellationToken cancellationToken = default) + { + await using var command = _connection.CreateCommand(); + command.CommandText = "SELECT COUNT(1) FROM item"; + var result = await command.ExecuteScalarAsync(cancellationToken); + return result == null ? 0L : (long) result; + } + + private Mutation CreateMutation(int batch, int rows) + { + var mutation = new Mutation + { + InsertOrUpdate = new Mutation.Types.Write + { + Table = "item", + Columns = { "i_id", "i_im_id", "i_name", "i_price", "i_data" }, + Values = + { + Capacity = _rowCount, + } + } + }; + for (var i = 0; i < rows; i++) + { + mutation.InsertOrUpdate.Values.Add(CreateRandomItem(batch, i)); + } + return mutation; + } + + private ListValue CreateRandomItem(int batch, int index) + { + var row = new ListValue + { + Values = + { + Capacity = 5 + } + }; + var id = (long)batch * RowsPerGroup + index; + row.Values.Add(Value.ForString($"{DataLoader.ReverseBitsUnsigned((ulong) id)}")); + row.Values.Add(Value.ForString($"{Random.Shared.Next(1, 2000001)}")); + row.Values.Add(Value.ForString(DataLoader.RandomString(24))); + row.Values.Add(Value.ForString(DataLoader.RandomDecimal(100, 10001))); + row.Values.Add(Value.ForString(DataLoader.RandomString(50))); + return row; + } +} \ No newline at end of file diff --git a/drivers/spanner-ado-net/spanner-ado-net-benchmarks/tpcc/loader/StockLoader.cs b/drivers/spanner-ado-net/spanner-ado-net-benchmarks/tpcc/loader/StockLoader.cs new file mode 100644 index 00000000..05cdff2c --- /dev/null +++ b/drivers/spanner-ado-net/spanner-ado-net-benchmarks/tpcc/loader/StockLoader.cs @@ -0,0 +1,119 @@ +using Google.Cloud.Spanner.V1; +using Google.Protobuf.WellKnownTypes; + +namespace Google.Cloud.Spanner.DataProvider.Benchmarks.tpcc.loader; + +internal class StockLoader +{ + private static readonly int RowsPerGroup = 1000; + + private readonly SpannerConnection _connection; + + private readonly int _warehouseCount; + + private readonly int _numItems; + + internal StockLoader(SpannerConnection connection, int warehouseCount, int numItems) + { + _connection = connection; + _warehouseCount = warehouseCount; + _numItems = numItems; + } + + internal async Task LoadAsync(CancellationToken cancellationToken = default) + { + var count = await CountAsync(cancellationToken); + if (count >= _warehouseCount * _numItems) + { + return; + } + for (var warehouse = 0; warehouse < _warehouseCount; warehouse++) + { + for (var item=0; item<_numItems; item += RowsPerGroup) + { + var group = new BatchWriteRequest.Types.MutationGroup + { + Mutations = { Capacity = 1 } + }; + group.Mutations.Add(CreateMutation(warehouse, item, RowsPerGroup)); + await _connection.WriteMutationsAsync(group, cancellationToken); + } + } + } + + private async Task CountAsync(CancellationToken cancellationToken = default) + { + await using var command = _connection.CreateCommand(); + command.CommandText = "SELECT COUNT(1) FROM stock"; + var result = await command.ExecuteScalarAsync(cancellationToken); + return result == null ? 0L : (long) result; + } + + private Mutation CreateMutation(int warehouse, int item, int rows) + { + var mutation = new Mutation + { + InsertOrUpdate = new Mutation.Types.Write + { + Table = "stock", + Columns = { "s_i_id", "w_id", "s_quantity", "s_dist_01", "s_dist_02", "s_dist_03", "s_dist_04", "s_dist_05", + "s_dist_06", "s_dist_07", "s_dist_08", "s_dist_09", "s_dist_10", "s_ytd", "s_order_cnt", "s_remote_cnt", "s_data" }, + Values = + { + Capacity = _numItems, + } + } + }; + for (var i = 0; i < rows; i++) + { + mutation.InsertOrUpdate.Values.Add(CreateRandomStock(warehouse, item, i)); + } + return mutation; + } + + private ListValue CreateRandomStock(int warehouse, int item, int index) + { + var row = new ListValue + { + Values = + { + Capacity = 10 + } + }; + // s_i_id int not null, + // w_id int not null, + // s_quantity int, + // s_dist_01 varchar(24), + // s_dist_02 varchar(24), + // s_dist_03 varchar(24), + // s_dist_04 varchar(24), + // s_dist_05 varchar(24), + // s_dist_06 varchar(24), + // s_dist_07 varchar(24), + // s_dist_08 varchar(24), + // s_dist_09 varchar(24), + // s_dist_10 varchar(24), + // s_ytd decimal, + // s_order_cnt int, + // s_remote_cnt int, + // s_data varchar(50), + row.Values.Add(Value.ForString($"{DataLoader.ReverseBitsUnsigned((ulong) (item + index))}")); + row.Values.Add(Value.ForString($"{DataLoader.ReverseBitsUnsigned((ulong) warehouse)}")); + row.Values.Add(Value.ForString(Random.Shared.Next(1, 500).ToString())); + row.Values.Add(Value.ForString(DataLoader.RandomString(24))); + row.Values.Add(Value.ForString(DataLoader.RandomString(24))); + row.Values.Add(Value.ForString(DataLoader.RandomString(24))); + row.Values.Add(Value.ForString(DataLoader.RandomString(24))); + row.Values.Add(Value.ForString(DataLoader.RandomString(24))); + row.Values.Add(Value.ForString(DataLoader.RandomString(24))); + row.Values.Add(Value.ForString(DataLoader.RandomString(24))); + row.Values.Add(Value.ForString(DataLoader.RandomString(24))); + row.Values.Add(Value.ForString(DataLoader.RandomString(24))); + row.Values.Add(Value.ForString(DataLoader.RandomString(24))); + row.Values.Add(Value.ForString("0.0")); + row.Values.Add(Value.ForString("0")); + row.Values.Add(Value.ForString("0")); + row.Values.Add(Value.ForString(DataLoader.RandomString(50))); + return row; + } +} \ No newline at end of file diff --git a/drivers/spanner-ado-net/spanner-ado-net-benchmarks/tpcc/loader/WarehouseLoader.cs b/drivers/spanner-ado-net/spanner-ado-net-benchmarks/tpcc/loader/WarehouseLoader.cs new file mode 100644 index 00000000..a9cb6fd4 --- /dev/null +++ b/drivers/spanner-ado-net/spanner-ado-net-benchmarks/tpcc/loader/WarehouseLoader.cs @@ -0,0 +1,67 @@ +using Google.Cloud.Spanner.V1; +using Google.Protobuf.WellKnownTypes; + +namespace Google.Cloud.Spanner.DataProvider.Benchmarks.tpcc.loader; + +internal class WarehouseLoader +{ + private readonly SpannerConnection _connection; + + private readonly int _rowCount; + + internal WarehouseLoader(SpannerConnection connection, int rowCount) + { + _connection = connection; + _rowCount = rowCount; + } + + internal Task LoadAsync(CancellationToken cancellationToken = default) + { + return _connection.WriteMutationsAsync(new BatchWriteRequest.Types.MutationGroup + { + Mutations = { CreateMutation() } + }, cancellationToken); + } + + private Mutation CreateMutation() + { + var mutation = new Mutation + { + InsertOrUpdate = new Mutation.Types.Write + { + Table = "warehouse", + Columns = { "w_id", "w_name", "w_street_1", "w_street_2", "w_city", "w_state", "w_zip", "w_tax", "w_ytd" }, + Values = + { + Capacity = _rowCount, + } + } + }; + for (var i = 0; i < _rowCount; i++) + { + mutation.InsertOrUpdate.Values.Add(CreateRandomWarehouse(i)); + } + return mutation; + } + + private ListValue CreateRandomWarehouse(int index) + { + var row = new ListValue + { + Values = + { + Capacity = 9 + } + }; + row.Values.Add(Value.ForString($"{DataLoader.ReverseBitsUnsigned((ulong) index)}")); + row.Values.Add(Value.ForString($"W#{index}")); + row.Values.Add(Value.ForString(DataLoader.RandomString(20))); + row.Values.Add(Value.ForString(DataLoader.RandomString(20))); + row.Values.Add(Value.ForString(DataLoader.RandomString(20))); + row.Values.Add(Value.ForString(DataLoader.RandomString(2))); + row.Values.Add(Value.ForString(DataLoader.RandomString(9))); + row.Values.Add(Value.ForString(DataLoader.RandomDecimal(0, 21))); + row.Values.Add(Value.ForString("0.0")); + return row; + } +} \ No newline at end of file diff --git a/drivers/spanner-ado-net/spanner-ado-net-samples/Program.cs b/drivers/spanner-ado-net/spanner-ado-net-samples/Program.cs new file mode 100644 index 00000000..c5287edf --- /dev/null +++ b/drivers/spanner-ado-net/spanner-ado-net-samples/Program.cs @@ -0,0 +1,37 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +using Google.Cloud.Spanner.DataProvider; + +using var connection = new SpannerConnection {ConnectionString = "/projects/appdev-soda-spanner-staging/instances/knut-test-ycsb/databases/knut-test-db"}; +connection.Open(); +using var cmd = connection.CreateCommand(); +cmd.CommandText = "select * from all_types where col_varchar is not null limit 10"; + +using var reader = cmd.ExecuteReader(); +for (int i = 0; i < reader.FieldCount; i++) +{ + Console.Write(reader.GetName(i)); + Console.Write("|"); +} +Console.WriteLine(); +while (reader.Read()) +{ + for (int i = 0; i < reader.FieldCount; i++) + { + Console.Write(reader.GetValue(i)); + Console.Write("|"); + } + Console.WriteLine(); +} \ No newline at end of file diff --git a/drivers/spanner-ado-net/spanner-ado-net-samples/README.md b/drivers/spanner-ado-net/spanner-ado-net-samples/README.md new file mode 100644 index 00000000..b07bb9b5 --- /dev/null +++ b/drivers/spanner-ado-net/spanner-ado-net-samples/README.md @@ -0,0 +1,5 @@ +# Spanner ADO.NET Data Provider Samples + +Samples for ADO.NET Data Provider for Spanner. + +__ALPHA: Not for production use__ diff --git a/drivers/spanner-ado-net/spanner-ado-net-samples/spanner-ado-net-samples.csproj b/drivers/spanner-ado-net/spanner-ado-net-samples/spanner-ado-net-samples.csproj new file mode 100644 index 00000000..3c247b26 --- /dev/null +++ b/drivers/spanner-ado-net/spanner-ado-net-samples/spanner-ado-net-samples.csproj @@ -0,0 +1,17 @@ + + + + Exe + net8.0 + Google.Cloud.Spanner.DataProvider.Samples + enable + enable + Google.Cloud.Spanner.DataProvider.Samples + default + + + + + + + diff --git a/drivers/spanner-ado-net/spanner-ado-net-specification-tests/CommandTests.cs b/drivers/spanner-ado-net/spanner-ado-net-specification-tests/CommandTests.cs new file mode 100644 index 00000000..a2feb8ab --- /dev/null +++ b/drivers/spanner-ado-net/spanner-ado-net-specification-tests/CommandTests.cs @@ -0,0 +1,123 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +using AdoNet.Specification.Tests; +using Google.Cloud.SpannerLib.MockServer; +using Xunit; +using TypeCode = Google.Cloud.Spanner.V1.TypeCode; + +namespace Google.Cloud.Spanner.DataProvider.SpecificationTests; + +public class CommandTests(DbFactoryFixture fixture) : CommandTestBase(fixture) +{ + [Fact] + public override void Execute_throws_for_unknown_ParameterValue_type() + { + Fixture.MockServerFixture.SpannerMock.AddOrUpdateStatementResult("SELECT @Parameter;", + StatementResult.CreateSingleColumnResultSet(new V1.Type{Code = TypeCode.String}, "Parameter", new CustomClass().ToString())); + base.Execute_throws_for_unknown_ParameterValue_type(); + } + + [Fact] + public override void ExecuteReader_binds_parameters() + { + Fixture.MockServerFixture.SpannerMock.AddOrUpdateStatementResult("SELECT @Parameter;", + StatementResult.CreateSingleColumnResultSet(new V1.Type{Code = TypeCode.Int64}, "Parameter", 1L)); + base.ExecuteReader_binds_parameters(); + } + + [Fact] + public override void ExecuteReader_supports_CloseConnection() + { + Fixture.MockServerFixture.SpannerMock.AddOrUpdateStatementResult("SELECT 0;", + StatementResult.CreateSingleColumnResultSet(new V1.Type{Code = TypeCode.Int64}, "c", 0L)); + base.ExecuteReader_supports_CloseConnection(); + } + + [Fact] + public override void ExecuteReader_works_when_trailing_comments() + { + Fixture.MockServerFixture.SpannerMock.AddOrUpdateStatementResult("SELECT 0; -- My favorite number", + StatementResult.CreateSingleColumnResultSet(new V1.Type{Code = TypeCode.Int64}, "c", 0L)); + base.ExecuteReader_works_when_trailing_comments(); + } + + [Fact] + public override void ExecuteScalar_returns_DBNull_when_null() + { + Fixture.MockServerFixture.SpannerMock.AddOrUpdateStatementResult("SELECT NULL;", + StatementResult.CreateSingleColumnResultSet(new V1.Type{Code = TypeCode.Int64}, "c", DBNull.Value)); + base.ExecuteScalar_returns_DBNull_when_null(); + } + + [Fact] + public override void ExecuteScalar_returns_first_when_multiple_columns() + { + Fixture.MockServerFixture.SpannerMock.AddOrUpdateStatementResult("SELECT 42, 43;", + StatementResult.CreateResultSet([Tuple.Create(TypeCode.Int64, "c1"), Tuple.Create(TypeCode.Int64, "c1")], [[42L, 43L]])); + base.ExecuteScalar_returns_first_when_multiple_columns(); + } + + [Fact] + public override void ExecuteScalar_returns_real() + { + Fixture.MockServerFixture.SpannerMock.AddOrUpdateStatementResult("SELECT 3.14;", + StatementResult.CreateSingleColumnResultSet(new V1.Type{Code = TypeCode.Float64}, "c", 3.14d)); + base.ExecuteScalar_returns_real(); + } + + [Fact] + public override void ExecuteScalar_returns_string_when_text() + { + Fixture.MockServerFixture.SpannerMock.AddOrUpdateStatementResult("SELECT 'test';", + StatementResult.CreateSingleColumnResultSet(new V1.Type{Code = TypeCode.String}, "c", "test")); + base.ExecuteScalar_returns_string_when_text(); + } + + [Fact(Skip = "Spanner does not support multiple SQL statements in one string")] + public override void ExecuteScalar_returns_first_when_batching() + { + } + + [Fact(Skip = "Spanner does not support empty statements")] + public override void ExecuteReader_HasRows_is_false_for_comment() + { + } + + [Fact(Skip = "Spanner does not use the command text once the reader has been opened")] + public override void CommandText_throws_when_set_when_open_reader() + { + } + + [Fact(Skip = "Spanner does not need the connection after the reader has been opened")] + public override void Connection_throws_when_set_when_open_reader() + { + } + + [Fact(Skip = "Spanner does not need the connection after the reader has been opened")] + public override void Connection_throws_when_set_to_null_when_open_reader() + { + } + + [Fact(Skip = "Spanner supports multiple open readers for one command")] + public override void ExecuteReader_throws_when_reader_open() + { + } + + [Fact(Skip = "Spanner only supports one transaction per connection and therefore ignores the transaction property")] + public override void ExecuteReader_throws_when_transaction_required() + { + } + +} diff --git a/drivers/spanner-ado-net/spanner-ado-net-specification-tests/ConnectionStringTests.cs b/drivers/spanner-ado-net/spanner-ado-net-specification-tests/ConnectionStringTests.cs new file mode 100644 index 00000000..0160687d --- /dev/null +++ b/drivers/spanner-ado-net/spanner-ado-net-specification-tests/ConnectionStringTests.cs @@ -0,0 +1,19 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +using AdoNet.Specification.Tests; + +namespace Google.Cloud.Spanner.DataProvider.SpecificationTests; + +public class ConnectionStringTests(DbFactoryFixture fixture) : ConnectionStringTestBase(fixture); \ No newline at end of file diff --git a/drivers/spanner-ado-net/spanner-ado-net-specification-tests/ConnectionTests.cs b/drivers/spanner-ado-net/spanner-ado-net-specification-tests/ConnectionTests.cs new file mode 100644 index 00000000..eb8f2dcb --- /dev/null +++ b/drivers/spanner-ado-net/spanner-ado-net-specification-tests/ConnectionTests.cs @@ -0,0 +1,19 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +using AdoNet.Specification.Tests; + +namespace Google.Cloud.Spanner.DataProvider.SpecificationTests; + +public class ConnectionTests(DbFactoryFixture fixture) : ConnectionTestBase(fixture); \ No newline at end of file diff --git a/drivers/spanner-ado-net/spanner-ado-net-specification-tests/DataReaderTests.cs b/drivers/spanner-ado-net/spanner-ado-net-specification-tests/DataReaderTests.cs new file mode 100644 index 00000000..028b0836 --- /dev/null +++ b/drivers/spanner-ado-net/spanner-ado-net-specification-tests/DataReaderTests.cs @@ -0,0 +1,107 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +using AdoNet.Specification.Tests; +using Google.Cloud.SpannerLib.MockServer; +using Xunit; +using TypeCode = Google.Cloud.Spanner.V1.TypeCode; + +namespace Google.Cloud.Spanner.DataProvider.SpecificationTests; + +public class DataReaderTests(DbFactoryFixture fixture) : DataReaderTestBase(fixture) +{ + [Fact(Skip = "SpannerLib does not support multiple statements in one query string")] + public override void HasRows_works_when_batching() + { + } + + [Fact(Skip = "SpannerLib does not support multiple statements in one query string")] + public override void NextResult_works() + { + } + + [Fact(Skip = "SpannerLib does not support multiple statements in one query string")] + public override void SingleResult_returns_one_result_set() + { + } + + [Fact(Skip = "SpannerLib does not support multiple statements in one query string")] + public override void SingleRow_returns_one_result_set() + { + } + + [Fact(Skip = "Getting stats after closing a DataReader is not supported")] + public override void RecordsAffected_returns_negative_1_after_close_when_no_rows() + { + } + + [Fact(Skip = "Getting stats after closing a DataReader is not supported")] + public override void RecordsAffected_returns_negative_1_after_dispose_when_no_rows() + { + } + + public override void GetFieldValue_works_utf8_four_bytes() + { + Fixture.MockServerFixture.SpannerMock.AddOrUpdateStatementResult("SELECT '😀';", + StatementResult.CreateSingleColumnResultSet(new V1.Type{Code = TypeCode.String}, "c", "😀")); + base.GetFieldValue_works_utf8_four_bytes(); + } + + public override void GetString_works_utf8_four_bytes() + { + Fixture.MockServerFixture.SpannerMock.AddOrUpdateStatementResult("SELECT '😀';", + StatementResult.CreateSingleColumnResultSet(new V1.Type{Code = TypeCode.String}, "c", "😀")); + base.GetString_works_utf8_four_bytes(); + } + + public override void GetValue_to_string_works_utf8_four_bytes() + { + Fixture.MockServerFixture.SpannerMock.AddOrUpdateStatementResult("SELECT '😀';", + StatementResult.CreateSingleColumnResultSet(new V1.Type{Code = TypeCode.String}, "c", "😀")); + base.GetValue_to_string_works_utf8_four_bytes(); + } + + public override void GetFieldValue_works_utf8_three_bytes() + { + Fixture.MockServerFixture.SpannerMock.AddOrUpdateStatementResult("SELECT 'Ḁ';", + StatementResult.CreateSingleColumnResultSet(new V1.Type{Code = TypeCode.String}, "c", "Ḁ")); + base.GetFieldValue_works_utf8_three_bytes(); + } + + public override void GetFieldValue_works_utf8_two_bytes() + { + Fixture.MockServerFixture.SpannerMock.AddOrUpdateStatementResult("SELECT 'Ä';", + StatementResult.CreateSingleColumnResultSet(new V1.Type{Code = TypeCode.String}, "c", "Ä")); + base.GetFieldValue_works_utf8_two_bytes(); + } + + public override void GetValues_works() + { + Fixture.MockServerFixture.SpannerMock.AddOrUpdateStatementResult("SELECT 'a', NULL;", + StatementResult.CreateResultSet([Tuple.Create(TypeCode.String, "c1"), Tuple.Create(TypeCode.Int64, "c2")], [["a", DBNull.Value]])); + base.GetValues_works(); + } + + public override void Item_by_name_works() + { + Fixture.MockServerFixture.SpannerMock.AddOrUpdateStatementResult("SELECT 'test' AS Id;", + StatementResult.CreateSingleColumnResultSet(new V1.Type{Code = TypeCode.String}, "Id", "test")); + base.Item_by_name_works(); + } + + [Fact(Skip = "The default implementation of GetTextReader returns an empty reader for null values")] + public override void GetTextReader_throws_for_null_String() + { + } +} \ No newline at end of file diff --git a/drivers/spanner-ado-net/spanner-ado-net-specification-tests/DbFactoryFixture.cs b/drivers/spanner-ado-net/spanner-ado-net-specification-tests/DbFactoryFixture.cs new file mode 100644 index 00000000..6c4e39e2 --- /dev/null +++ b/drivers/spanner-ado-net/spanner-ado-net-specification-tests/DbFactoryFixture.cs @@ -0,0 +1,441 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +using System.Data; +using System.Data.Common; +using AdoNet.Specification.Tests; +using Google.Cloud.SpannerLib.MockServer; +using TypeCode = Google.Cloud.Spanner.V1.TypeCode; + +namespace Google.Cloud.Spanner.DataProvider.SpecificationTests; + +public class DbFactoryFixture : IDisposable, ISelectValueFixture, IDeleteFixture +{ + static DbFactoryFixture() + { + AppDomain.CurrentDomain.ProcessExit += (_, _) => + { + SpannerPool.CloseSpannerLib(); + }; + } + + private bool _disposed; + internal readonly SpannerMockServerFixture MockServerFixture = new (); + + public DbProviderFactory Factory => SpannerFactory.Instance; + public string ConnectionString => $"Host={MockServerFixture.Host};Port={MockServerFixture.Port};Data Source=projects/p1/instances/i1/databases/d1;UsePlainText=true"; + + public IReadOnlyCollection SupportedDbTypes { get; } = [ + DbType.Binary, + DbType.Boolean, + DbType.Date, + DbType.DateTime, + DbType.Decimal, + DbType.Double, + DbType.Guid, + DbType.Int64, + DbType.Single, + DbType.String, + ]; + public string SelectNoRows => "select * from (select 1) where false"; + public System.Type NullValueExceptionType { get; } = typeof(InvalidCastException); + public string DeleteNoRows => "delete from foo where false"; + + public DbFactoryFixture() + { + Reset(); + } + + public void Reset() + { + MockServerFixture.SpannerMock.Reset(); + MockServerFixture.SpannerMock.AddOrUpdateStatementResult("SELECT 1;", StatementResult.CreateSelect1ResultSet()); + MockServerFixture.SpannerMock.AddOrUpdateStatementResult("SELECT 1", StatementResult.CreateSelect1ResultSet()); + MockServerFixture.SpannerMock.AddOrUpdateStatementResult("SELECT NULL;", + StatementResult.CreateSingleColumnResultSet(new V1.Type{Code = TypeCode.Int64}, "c", DBNull.Value)); + MockServerFixture.SpannerMock.AddOrUpdateStatementResult("SELECT 1 AS id;", + StatementResult.CreateSingleColumnResultSet(new V1.Type{Code = TypeCode.Int64}, "id", 1)); + MockServerFixture.SpannerMock.AddOrUpdateStatementResult("SELECT 1 AS Id;", + StatementResult.CreateSingleColumnResultSet(new V1.Type{Code = TypeCode.Int64}, "Id", 1)); + MockServerFixture.SpannerMock.AddOrUpdateStatementResult("SELECT 'test';", + StatementResult.CreateSingleColumnResultSet(new V1.Type{Code = TypeCode.String}, "c", "test")); + MockServerFixture.SpannerMock.AddOrUpdateStatementResult("SELECT 'ab¢d';", + StatementResult.CreateSingleColumnResultSet(new V1.Type{Code = TypeCode.String}, "c", "ab¢d")); + MockServerFixture.SpannerMock.AddOrUpdateStatementResult(SelectNoRows, + StatementResult.CreateSingleColumnResultSet(new V1.Type{Code = TypeCode.Int64}, "c")); + MockServerFixture.SpannerMock.AddOrUpdateStatementResult("SELECT 42 UNION SELECT 43;", + StatementResult.CreateSingleColumnResultSet(new V1.Type{Code = TypeCode.Int64}, "c", 42, 43)); + MockServerFixture.SpannerMock.AddOrUpdateStatementResult("SELECT 1 UNION SELECT 2;", + StatementResult.CreateSingleColumnResultSet(new V1.Type{Code = TypeCode.Int64}, "c", 1, 2)); + MockServerFixture.SpannerMock.AddOrUpdateStatementResult(DeleteNoRows, StatementResult.CreateUpdateCount(0)); + } + + public string CreateSelectSql(DbType dbType, ValueKind kind) + { + return dbType switch + { + DbType.Binary => CreateSelectSqlBinary(kind), + DbType.Boolean => CreateSelectSqlBoolean(kind), + DbType.Date => CreateSelectSqlDate(kind), + DbType.DateTime => CreateSelectSqlDateTime(kind), + DbType.Decimal => CreateSelectSqlDecimal(kind), + DbType.Double => CreateSelectSqlDouble(kind), + DbType.Guid => CreateSelectSqlGuid(kind), + DbType.Int64 => CreateSelectSqlInt64(kind), + DbType.Single => CreateSelectSqlSingle(kind), + DbType.String => CreateSelectSqlString(kind), + _ => throw new NotImplementedException("Not implemented") + }; + } + + private string CreateSelectSqlBinary(ValueKind kind) + { + var sql = "SELECT bytes_col FROM my_table;"; + switch (kind) + { + case ValueKind.Empty: + MockServerFixture.SpannerMock.AddOrUpdateStatementResult(sql, + StatementResult.CreateSingleColumnResultSet(new V1.Type { Code = TypeCode.Bytes }, "bytes_col", Array.Empty())); + break; + case ValueKind.Zero: + MockServerFixture.SpannerMock.AddOrUpdateStatementResult(sql, + StatementResult.CreateSingleColumnResultSet(new V1.Type { Code = TypeCode.Bytes }, "bytes_col", new byte[]{0})); + break; + case ValueKind.One: + MockServerFixture.SpannerMock.AddOrUpdateStatementResult(sql, + StatementResult.CreateSingleColumnResultSet(new V1.Type { Code = TypeCode.Bytes }, "bytes_col", new byte[]{0x11})); + break; + case ValueKind.Maximum: + case ValueKind.Minimum: + case ValueKind.Null: + MockServerFixture.SpannerMock.AddOrUpdateStatementResult(sql, + StatementResult.CreateSingleColumnResultSet(new V1.Type { Code = TypeCode.Bytes }, "bytes_col", DBNull.Value)); + break; + default: + throw new NotImplementedException("Not implemented"); + } + return sql; + } + private string CreateSelectSqlBoolean(ValueKind kind) + { + var sql = "SELECT bool_col FROM my_table;"; + switch (kind) + { + case ValueKind.Maximum: + case ValueKind.One: + MockServerFixture.SpannerMock.AddOrUpdateStatementResult(sql, + StatementResult.CreateSingleColumnResultSet(new V1.Type { Code = TypeCode.Bool }, "bool_col", true)); + break; + case ValueKind.Empty: + case ValueKind.Minimum: + case ValueKind.Zero: + MockServerFixture.SpannerMock.AddOrUpdateStatementResult(sql, + StatementResult.CreateSingleColumnResultSet(new V1.Type { Code = TypeCode.Bool }, "bool_col", false)); + break; + case ValueKind.Null: + MockServerFixture.SpannerMock.AddOrUpdateStatementResult(sql, + StatementResult.CreateSingleColumnResultSet(new V1.Type { Code = TypeCode.Bool }, "bool_col", DBNull.Value)); + break; + default: + throw new NotImplementedException("Not implemented"); + } + return sql; + } + + private string CreateSelectSqlDate(ValueKind kind) + { + var sql = "SELECT date_col FROM my_table;"; + switch (kind) + { + case ValueKind.One: + MockServerFixture.SpannerMock.AddOrUpdateStatementResult(sql, + StatementResult.CreateSingleColumnResultSet(new V1.Type { Code = TypeCode.Date }, "date_col", "1111-11-11")); + break; + case ValueKind.Maximum: + MockServerFixture.SpannerMock.AddOrUpdateStatementResult(sql, + StatementResult.CreateSingleColumnResultSet(new V1.Type { Code = TypeCode.Date }, "date_col", "9999-12-31")); + break; + case ValueKind.Minimum: + MockServerFixture.SpannerMock.AddOrUpdateStatementResult(sql, + StatementResult.CreateSingleColumnResultSet(new V1.Type { Code = TypeCode.Date }, "date_col", "0001-01-01")); + break; + case ValueKind.Zero: + case ValueKind.Empty: + case ValueKind.Null: + MockServerFixture.SpannerMock.AddOrUpdateStatementResult(sql, + StatementResult.CreateSingleColumnResultSet(new V1.Type { Code = TypeCode.Date }, "date_col", DBNull.Value)); + break; + default: + throw new NotImplementedException("Not implemented"); + } + return sql; + } + + private string CreateSelectSqlDateTime(ValueKind kind) + { + var sql = "SELECT timestamp_col FROM my_table;"; + switch (kind) + { + case ValueKind.One: + MockServerFixture.SpannerMock.AddOrUpdateStatementResult(sql, + StatementResult.CreateSingleColumnResultSet(new V1.Type { Code = TypeCode.Timestamp }, "timestamp_col", "1111-11-11T11:11:11.111000000Z")); + break; + case ValueKind.Maximum: + MockServerFixture.SpannerMock.AddOrUpdateStatementResult(sql, + StatementResult.CreateSingleColumnResultSet(new V1.Type { Code = TypeCode.Timestamp }, "timestamp_col", "9999-12-31T23:59:59.999000000Z")); + break; + case ValueKind.Minimum: + MockServerFixture.SpannerMock.AddOrUpdateStatementResult(sql, + StatementResult.CreateSingleColumnResultSet(new V1.Type { Code = TypeCode.Timestamp }, "timestamp_col", "0001-01-01T00:00:00Z")); + break; + case ValueKind.Zero: + case ValueKind.Empty: + case ValueKind.Null: + MockServerFixture.SpannerMock.AddOrUpdateStatementResult(sql, + StatementResult.CreateSingleColumnResultSet(new V1.Type { Code = TypeCode.Timestamp }, "timestamp_col", DBNull.Value)); + break; + default: + throw new NotImplementedException("Not implemented"); + } + return sql; + } + + private string CreateSelectSqlDecimal(ValueKind kind) + { + var sql = "SELECT numeric_col FROM my_table;"; + switch (kind) + { + case ValueKind.Zero: + MockServerFixture.SpannerMock.AddOrUpdateStatementResult(sql, + StatementResult.CreateSingleColumnResultSet(new V1.Type { Code = TypeCode.Numeric }, "numeric_col", "0")); + break; + case ValueKind.One: + MockServerFixture.SpannerMock.AddOrUpdateStatementResult(sql, + StatementResult.CreateSingleColumnResultSet(new V1.Type { Code = TypeCode.Numeric }, "numeric_col", "1")); + break; + case ValueKind.Maximum: + MockServerFixture.SpannerMock.AddOrUpdateStatementResult(sql, + StatementResult.CreateSingleColumnResultSet(new V1.Type { Code = TypeCode.Numeric }, "numeric_col", "99999999999999999999.999999999999999")); + break; + case ValueKind.Minimum: + MockServerFixture.SpannerMock.AddOrUpdateStatementResult(sql, + StatementResult.CreateSingleColumnResultSet(new V1.Type { Code = TypeCode.Numeric }, "numeric_col", "0.000000000000001")); + break; + case ValueKind.Empty: + case ValueKind.Null: + MockServerFixture.SpannerMock.AddOrUpdateStatementResult(sql, + StatementResult.CreateSingleColumnResultSet(new V1.Type { Code = TypeCode.Numeric }, "numeric_col", DBNull.Value)); + break; + default: + throw new NotImplementedException("Not implemented"); + } + return sql; + } + + private string CreateSelectSqlDouble(ValueKind kind) + { + var sql = "SELECT float64_col FROM my_table;"; + switch (kind) + { + case ValueKind.Zero: + MockServerFixture.SpannerMock.AddOrUpdateStatementResult(sql, + StatementResult.CreateSingleColumnResultSet(new V1.Type { Code = TypeCode.Float64 }, "float64_col", 0.0d)); + break; + case ValueKind.One: + MockServerFixture.SpannerMock.AddOrUpdateStatementResult(sql, + StatementResult.CreateSingleColumnResultSet(new V1.Type { Code = TypeCode.Float64 }, "float64_col", 1.0d)); + break; + case ValueKind.Maximum: + MockServerFixture.SpannerMock.AddOrUpdateStatementResult(sql, + StatementResult.CreateSingleColumnResultSet(new V1.Type { Code = TypeCode.Float64 }, "float64_col", 1.79e308d)); + break; + case ValueKind.Minimum: + MockServerFixture.SpannerMock.AddOrUpdateStatementResult(sql, + StatementResult.CreateSingleColumnResultSet(new V1.Type { Code = TypeCode.Float64 }, "float64_col", 2.23e-308d)); + break; + case ValueKind.Empty: + case ValueKind.Null: + MockServerFixture.SpannerMock.AddOrUpdateStatementResult(sql, + StatementResult.CreateSingleColumnResultSet(new V1.Type { Code = TypeCode.Float64 }, "float64_col", DBNull.Value)); + break; + default: + throw new NotImplementedException("Not implemented"); + } + return sql; + } + + private string CreateSelectSqlGuid(ValueKind kind) + { + var sql = "SELECT uuid_col FROM my_table;"; + switch (kind) + { + case ValueKind.Zero: + MockServerFixture.SpannerMock.AddOrUpdateStatementResult(sql, + StatementResult.CreateSingleColumnResultSet(new V1.Type { Code = TypeCode.Uuid }, "uuid_col", "00000000-0000-0000-0000-000000000000")); + break; + case ValueKind.One: + MockServerFixture.SpannerMock.AddOrUpdateStatementResult(sql, + StatementResult.CreateSingleColumnResultSet(new V1.Type { Code = TypeCode.Uuid }, "uuid_col", "11111111-1111-1111-1111-111111111111")); + break; + case ValueKind.Maximum: + MockServerFixture.SpannerMock.AddOrUpdateStatementResult(sql, + StatementResult.CreateSingleColumnResultSet(new V1.Type { Code = TypeCode.Uuid }, "uuid_col", "ccddeeff-aabb-8899-7766-554433221100")); + break; + case ValueKind.Minimum: + MockServerFixture.SpannerMock.AddOrUpdateStatementResult(sql, + StatementResult.CreateSingleColumnResultSet(new V1.Type { Code = TypeCode.Uuid }, "uuid_col", "33221100-5544-7766-9988-aabbccddeeff")); + break; + case ValueKind.Empty: + case ValueKind.Null: + MockServerFixture.SpannerMock.AddOrUpdateStatementResult(sql, + StatementResult.CreateSingleColumnResultSet(new V1.Type { Code = TypeCode.Uuid }, "uuid_col", DBNull.Value)); + break; + default: + throw new NotImplementedException("Not implemented"); + } + return sql; + } + + private string CreateSelectSqlInt64(ValueKind kind) + { + var sql = "SELECT int64_col FROM my_table;"; + switch (kind) + { + case ValueKind.Zero: + MockServerFixture.SpannerMock.AddOrUpdateStatementResult(sql, + StatementResult.CreateSingleColumnResultSet(new V1.Type { Code = TypeCode.Int64 }, "int64_col", 0L)); + break; + case ValueKind.One: + MockServerFixture.SpannerMock.AddOrUpdateStatementResult(sql, + StatementResult.CreateSingleColumnResultSet(new V1.Type { Code = TypeCode.Int64 }, "int64_col", 1L)); + break; + case ValueKind.Maximum: + MockServerFixture.SpannerMock.AddOrUpdateStatementResult(sql, + StatementResult.CreateSingleColumnResultSet(new V1.Type { Code = TypeCode.Int64 }, "int64_col", long.MaxValue)); + break; + case ValueKind.Minimum: + MockServerFixture.SpannerMock.AddOrUpdateStatementResult(sql, + StatementResult.CreateSingleColumnResultSet(new V1.Type { Code = TypeCode.Int64 }, "int64_col", long.MinValue)); + break; + case ValueKind.Empty: + case ValueKind.Null: + MockServerFixture.SpannerMock.AddOrUpdateStatementResult(sql, + StatementResult.CreateSingleColumnResultSet(new V1.Type { Code = TypeCode.Int64 }, "int64_col", DBNull.Value)); + break; + default: + throw new NotImplementedException("Not implemented"); + } + return sql; + } + + private string CreateSelectSqlSingle(ValueKind kind) + { + var sql = "SELECT float32_col FROM my_table;"; + switch (kind) + { + case ValueKind.Zero: + MockServerFixture.SpannerMock.AddOrUpdateStatementResult(sql, + StatementResult.CreateSingleColumnResultSet(new V1.Type { Code = TypeCode.Float32 }, "float32_col", 0.0f)); + break; + case ValueKind.One: + MockServerFixture.SpannerMock.AddOrUpdateStatementResult(sql, + StatementResult.CreateSingleColumnResultSet(new V1.Type { Code = TypeCode.Float32 }, "float32_col", 1.0f)); + break; + case ValueKind.Maximum: + MockServerFixture.SpannerMock.AddOrUpdateStatementResult(sql, + StatementResult.CreateSingleColumnResultSet(new V1.Type { Code = TypeCode.Float32 }, "float32_col", 3.40e38f)); + break; + case ValueKind.Minimum: + MockServerFixture.SpannerMock.AddOrUpdateStatementResult(sql, + StatementResult.CreateSingleColumnResultSet(new V1.Type { Code = TypeCode.Float32 }, "float32_col", 1.18e-38f)); + break; + case ValueKind.Empty: + case ValueKind.Null: + MockServerFixture.SpannerMock.AddOrUpdateStatementResult(sql, + StatementResult.CreateSingleColumnResultSet(new V1.Type { Code = TypeCode.Float32 }, "float32_col", DBNull.Value)); + break; + default: + throw new NotImplementedException("Not implemented"); + } + return sql; + } + + private string CreateSelectSqlString(ValueKind kind) + { + var sql = "SELECT string_col FROM my_table;"; + switch (kind) + { + case ValueKind.Zero: + MockServerFixture.SpannerMock.AddOrUpdateStatementResult(sql, + StatementResult.CreateSingleColumnResultSet(new V1.Type { Code = TypeCode.String }, "string_col", "0")); + break; + case ValueKind.One: + MockServerFixture.SpannerMock.AddOrUpdateStatementResult(sql, + StatementResult.CreateSingleColumnResultSet(new V1.Type { Code = TypeCode.String }, "string_col", "1")); + break; + case ValueKind.Empty: + MockServerFixture.SpannerMock.AddOrUpdateStatementResult(sql, + StatementResult.CreateSingleColumnResultSet(new V1.Type { Code = TypeCode.String }, "string_col", "")); + break; + case ValueKind.Maximum: + case ValueKind.Minimum: + case ValueKind.Null: + MockServerFixture.SpannerMock.AddOrUpdateStatementResult(sql, + StatementResult.CreateSingleColumnResultSet(new V1.Type { Code = TypeCode.String }, "string_col", DBNull.Value)); + break; + default: + throw new NotImplementedException("Not implemented"); + } + return sql; + } + + public string CreateSelectSql(byte[] value) + { + var sql = "SELECT bytes_col FROM my_table;"; + MockServerFixture.SpannerMock.AddOrUpdateStatementResult(sql, + StatementResult.CreateSingleColumnResultSet(new V1.Type{Code = TypeCode.Bytes}, "bytes_col", value)); + + return sql; + } + + protected void MarkDisposed() + { + _disposed = true; + } + + public void Dispose() + { + Dispose(true); + GC.SuppressFinalize(this); + } + + protected virtual void Dispose(bool disposing) + { + if (_disposed) + { + return; + } + try + { + MockServerFixture.Dispose(); + // var source = new CancellationTokenSource(); + // source.CancelAfter(1000); + // Task.Run(() => SpannerPool.CloseSpannerLibWhenAllConnectionsClosedAsync(source.Token), source.Token).Wait(source.Token); + } + finally + { + _disposed = true; + } + } +} diff --git a/drivers/spanner-ado-net/spanner-ado-net-specification-tests/DbFactoryTests.cs b/drivers/spanner-ado-net/spanner-ado-net-specification-tests/DbFactoryTests.cs new file mode 100644 index 00000000..7a34fecc --- /dev/null +++ b/drivers/spanner-ado-net/spanner-ado-net-specification-tests/DbFactoryTests.cs @@ -0,0 +1,19 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +using AdoNet.Specification.Tests; + +namespace Google.Cloud.Spanner.DataProvider.SpecificationTests; + +public class DbFactoryTests(DbFactoryFixture fixture) : DbFactoryTestBase(fixture); \ No newline at end of file diff --git a/drivers/spanner-ado-net/spanner-ado-net-specification-tests/DbProviderFactoryTests.cs b/drivers/spanner-ado-net/spanner-ado-net-specification-tests/DbProviderFactoryTests.cs new file mode 100644 index 00000000..6ddf7eee --- /dev/null +++ b/drivers/spanner-ado-net/spanner-ado-net-specification-tests/DbProviderFactoryTests.cs @@ -0,0 +1,19 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +using AdoNet.Specification.Tests; + +namespace Google.Cloud.Spanner.DataProvider.SpecificationTests; + +public class DbProviderFactoryTests(DbFactoryFixture fixture) : DbProviderFactoryTestBase(fixture); \ No newline at end of file diff --git a/drivers/spanner-ado-net/spanner-ado-net-specification-tests/GetValueConversionTests.cs b/drivers/spanner-ado-net/spanner-ado-net-specification-tests/GetValueConversionTests.cs new file mode 100644 index 00000000..4f4870f0 --- /dev/null +++ b/drivers/spanner-ado-net/spanner-ado-net-specification-tests/GetValueConversionTests.cs @@ -0,0 +1,138 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +using System.Data; +using System.Globalization; +using AdoNet.Specification.Tests; + +namespace Google.Cloud.Spanner.DataProvider.SpecificationTests; + +public class GetValueConversionTests(DbFactoryFixture fixture) : GetValueConversionTestBase(fixture) +{ + // Spanner uses DateOnly for DATE columns. + public override void GetFieldType_for_Date() => TestGetFieldType(DbType.Date, ValueKind.One, typeof(DateOnly)); + + public override void GetValue_for_Date() => TestGetValue(DbType.Date, ValueKind.One, new DateOnly(1111, 11, 11)); + + + // Spanner allows string values to be cast to numerical values. + public override void GetDecimal_throws_for_zero_String() => TestGetValue(DbType.String, ValueKind.Zero, x => x.GetDecimal(0), 0.0m); + + public override void GetDecimal_throws_for_one_String() => TestGetValue(DbType.String, ValueKind.One, x => x.GetDecimal(0), 1.0m); + + public override void GetDouble_throws_for_zero_String() => TestGetValue(DbType.String, ValueKind.Zero, x => x.GetDouble(0), 0.0d); + + public override void GetDouble_throws_for_one_String() => TestGetValue(DbType.String, ValueKind.One, x => x.GetDouble(0), 1.0d); + + public override void GetDouble_throws_for_zero_String_with_GetFieldValue() => TestGetValue(DbType.String, ValueKind.Zero, x => x.GetFieldValue(0), 0.0d); + + public override async Task GetDouble_throws_for_zero_String_with_GetFieldValueAsync() => await TestGetValueAsync(DbType.String, ValueKind.Zero, async x => await x.GetFieldValueAsync(0), 0.0d); + + public override void GetFloat_throws_for_one_String() => TestGetValue(DbType.String, ValueKind.One, x => x.GetFloat(0), 1.0f); + + public override void GetFloat_throws_for_zero_String() => TestGetValue(DbType.String, ValueKind.Zero, x => x.GetFloat(0), 0.0f); + + public override void GetFloat_throws_for_zero_String_with_GetFieldValue() => TestGetValue(DbType.String, ValueKind.Zero, x => x.GetFieldValue(0), 0.0f); + + public override async Task GetFloat_throws_for_zero_String_with_GetFieldValueAsync() => await TestGetValueAsync(DbType.String, ValueKind.Zero, async x => await x.GetFieldValueAsync(0), 0.0f); + + public override void GetInt16_throws_for_one_String() => TestGetValue(DbType.String, ValueKind.One, x => x.GetInt16(0), (short) 1); + + public override void GetInt16_throws_for_zero_String() => TestGetValue(DbType.String, ValueKind.Zero, x => x.GetInt16(0), (short) 0); + + public override void GetInt16_throws_for_zero_String_with_GetFieldValue() => TestGetValue(DbType.String, ValueKind.Zero, x => x.GetFieldValue(0), (short) 0); + + public override async Task GetInt16_throws_for_zero_String_with_GetFieldValueAsync() => await TestGetValueAsync(DbType.String, ValueKind.Zero, async x => await x.GetFieldValueAsync(0), (short) 0); + + public override void GetInt32_throws_for_one_String() => TestGetValue(DbType.String, ValueKind.One, x => x.GetInt32(0), 1); + + public override void GetInt32_throws_for_zero_String() => TestGetValue(DbType.String, ValueKind.Zero, x => x.GetInt32(0), 0); + + public override void GetInt32_throws_for_zero_String_with_GetFieldValue() => TestGetValue(DbType.String, ValueKind.Zero, x => x.GetFieldValue(0), 0); + + public override async Task GetInt32_throws_for_zero_String_with_GetFieldValueAsync() => await TestGetValueAsync(DbType.String, ValueKind.Zero, async x => await x.GetFieldValueAsync(0), 0); + + public override void GetInt64_throws_for_one_String() => TestGetValue(DbType.String, ValueKind.One, x => x.GetInt64(0), 1L); + + public override void GetInt64_throws_for_zero_String() => TestGetValue(DbType.String, ValueKind.Zero, x => x.GetInt64(0), 0L); + + public override void GetInt64_throws_for_zero_String_with_GetFieldValue() => TestGetValue(DbType.String, ValueKind.Zero, x => x.GetFieldValue(0), 0L); + + public override async Task GetInt64_throws_for_zero_String_with_GetFieldValueAsync() => await TestGetValueAsync(DbType.String, ValueKind.Zero, async x => await x.GetFieldValueAsync(0), 0L); + + public override void GetString_throws_for_maximum_Boolean() => TestGetValue(DbType.Boolean, ValueKind.Maximum, x => x.GetString(0), "True"); + + public override void GetString_throws_for_maximum_Decimal() => TestGetValue(DbType.Decimal, ValueKind.Maximum, x => x.GetString(0), "99999999999999999999.999999999999999"); + + public override void GetString_throws_for_maximum_Double() => TestGetValue(DbType.Double, ValueKind.Maximum, x => x.GetString(0), "1.79E+308"); + + public override void GetString_throws_for_maximum_Int64() => TestGetValue(DbType.Int64, ValueKind.Maximum, x => x.GetString(0), long.MaxValue.ToString(CultureInfo.InvariantCulture)); + + public override void GetString_throws_for_maximum_Single() => TestGetValue(DbType.Single, ValueKind.Maximum, x => x.GetString(0), 3.40e38f.ToString(CultureInfo.InvariantCulture)); + + public override void GetString_throws_for_minimum_Boolean() => TestGetValue(DbType.Boolean, ValueKind.Minimum, x => x.GetString(0), "False"); + + public override void GetString_throws_for_minimum_Decimal() => TestGetValue(DbType.Decimal, ValueKind.Minimum, x => x.GetString(0), "0.000000000000001"); + + public override void GetString_throws_for_minimum_Double() => TestGetValue(DbType.Double, ValueKind.Minimum, x => x.GetString(0), "2.23E-308"); + + public override void GetString_throws_for_minimum_Int64() => TestGetValue(DbType.Int64, ValueKind.Minimum, x => x.GetString(0), long.MinValue.ToString(CultureInfo.InvariantCulture)); + + public override void GetString_throws_for_minimum_Single() => TestGetValue(DbType.Single, ValueKind.Minimum, x => x.GetString(0), "1.18E-38"); + + public override void GetString_throws_for_one_Boolean() => TestGetValue(DbType.Boolean, ValueKind.One, x => x.GetString(0), "True"); + + public override void GetString_throws_for_one_Decimal() => TestGetValue(DbType.Decimal, ValueKind.One, x => x.GetString(0), "1"); + + public override void GetString_throws_for_one_Double() => TestGetValue(DbType.Double, ValueKind.One, x => x.GetString(0), "1"); + + public override void GetString_throws_for_one_Guid() => TestGetValue(DbType.Guid, ValueKind.One, x => x.GetString(0), "11111111-1111-1111-1111-111111111111"); + + public override void GetString_throws_for_one_Int64() => TestGetValue(DbType.Int64, ValueKind.One, x => x.GetString(0), "1"); + + public override void GetString_throws_for_one_Single() => TestGetValue(DbType.Single, ValueKind.One, x => x.GetString(0), "1"); + + public override void GetString_throws_for_zero_Boolean() => TestGetValue(DbType.Boolean, ValueKind.Zero, x => x.GetString(0), "False"); + + public override void GetString_throws_for_zero_Decimal() => TestGetValue(DbType.Decimal, ValueKind.Zero, x => x.GetString(0), "0"); + + public override void GetString_throws_for_zero_Double() => TestGetValue(DbType.Double, ValueKind.Zero, x => x.GetString(0), "0"); + + public override void GetString_throws_for_zero_Guid() => TestGetValue(DbType.Guid, ValueKind.Zero, x => x.GetString(0), "00000000-0000-0000-0000-000000000000"); + + public override void GetString_throws_for_zero_Int64() => TestGetValue(DbType.Int64, ValueKind.Zero, x => x.GetString(0), "0"); + + public override void GetString_throws_for_zero_Single() => TestGetValue(DbType.Single, ValueKind.Zero, x => x.GetString(0), "0"); + + public override void GetDouble_throws_for_one_String_with_GetFieldValue() => TestGetValue(DbType.String, ValueKind.One, x => x.GetFieldValue(0), 1.0d); + + public override async Task GetDouble_throws_for_one_String_with_GetFieldValueAsync() => await TestGetValueAsync(DbType.String, ValueKind.One, async x => await x.GetFieldValueAsync(0), 1.0d); + + public override void GetFloat_throws_for_one_String_with_GetFieldValue() => TestGetValue(DbType.String, ValueKind.One, x => x.GetFieldValue(0), 1.0f); + + public override async Task GetFloat_throws_for_one_String_with_GetFieldValueAsync() => await TestGetValueAsync(DbType.String, ValueKind.One, async x => await x.GetFieldValueAsync(0), 1.0f); + + public override void GetInt16_throws_for_one_String_with_GetFieldValue() => TestGetValue(DbType.String, ValueKind.One, x => x.GetFieldValue(0), (short) 1); + + public override async Task GetInt16_throws_for_one_String_with_GetFieldValueAsync() => await TestGetValueAsync(DbType.String, ValueKind.One, async x => await x.GetFieldValueAsync(0), (short) 1); + + public override void GetInt32_throws_for_one_String_with_GetFieldValue() => TestGetValue(DbType.String, ValueKind.One, x => x.GetFieldValue(0), 1); + + public override async Task GetInt32_throws_for_one_String_with_GetFieldValueAsync() => await TestGetValueAsync(DbType.String, ValueKind.One, async x => await x.GetFieldValueAsync(0), 1); + + public override void GetInt64_throws_for_one_String_with_GetFieldValue() => TestGetValue(DbType.String, ValueKind.One, x => x.GetFieldValue(0), 1L); + + public override async Task GetInt64_throws_for_one_String_with_GetFieldValueAsync() => await TestGetValueAsync(DbType.String, ValueKind.One, async x => await x.GetFieldValueAsync(0), 1L); + +} diff --git a/drivers/spanner-ado-net/spanner-ado-net-specification-tests/ParameterTests.cs b/drivers/spanner-ado-net/spanner-ado-net-specification-tests/ParameterTests.cs new file mode 100644 index 00000000..1a33a87b --- /dev/null +++ b/drivers/spanner-ado-net/spanner-ado-net-specification-tests/ParameterTests.cs @@ -0,0 +1,83 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +using AdoNet.Specification.Tests; +using Google.Cloud.Spanner.V1; +using Google.Cloud.SpannerLib.MockServer; +using Xunit; +using TypeCode = Google.Cloud.Spanner.V1.TypeCode; + +namespace Google.Cloud.Spanner.DataProvider.SpecificationTests; + +public class ParameterTests(DbFactoryFixture fixture) : ParameterTestBase(fixture) +{ + protected override Task OnInitializeAsync() + { + Fixture.Reset(); + return base.OnInitializeAsync(); + } + + [Fact(Skip = "Spanner assumes that it is a positional parameter if it has no name")] + public override void Bind_requires_set_name() + { + } + + [Fact(Skip = "Unknown parameters are converted to strings")] + public override void Bind_throws_when_unknown() + { + } + + [Fact] + public override void Bind_works_with_byte_array() + { + Fixture.MockServerFixture.SpannerMock.AddOrUpdateStatementResult("SELECT @Parameter;", + StatementResult.CreateSingleColumnResultSet(new V1.Type{Code = TypeCode.Bytes}, "Parameter", new byte[]{1,2,3,4})); + base.Bind_works_with_byte_array(); + + var requests = Fixture.MockServerFixture.SpannerMock.Requests.OfType(); + var request = Assert.Single(requests); + Assert.Equal(Convert.ToBase64String(new byte[]{1,2,3,4}), request.Params.Fields["Parameter"].StringValue); + // The parameter value should be sent as an untyped string. + Assert.False(request.ParamTypes.ContainsKey("Parameter")); + } + + [Fact] + public override void Bind_works_with_stream() + { + Fixture.MockServerFixture.SpannerMock.AddOrUpdateStatementResult("SELECT @Parameter;", + StatementResult.CreateSingleColumnResultSet(new V1.Type{Code = TypeCode.Bytes}, "Parameter", new byte[]{1,2,3,4})); + base.Bind_works_with_stream(); + + var requests = Fixture.MockServerFixture.SpannerMock.Requests.OfType(); + var request = Assert.Single(requests); + Assert.Equal(Convert.ToBase64String(new byte[]{1,2,3,4}), request.Params.Fields["Parameter"].StringValue); + // The parameter value should be sent as an untyped string. + Assert.False(request.ParamTypes.ContainsKey("Parameter")); + } + + [Fact] + public override void Bind_works_with_string() + { + Fixture.MockServerFixture.SpannerMock.AddOrUpdateStatementResult("SELECT @Parameter;", + StatementResult.CreateSingleColumnResultSet(new V1.Type{Code = TypeCode.String}, "Parameter", "test")); + base.Bind_works_with_string(); + + var requests = Fixture.MockServerFixture.SpannerMock.Requests.OfType(); + var request = Assert.Single(requests); + Assert.Equal("test", request.Params.Fields["Parameter"].StringValue); + // The parameter value should be sent as an untyped string. + Assert.False(request.ParamTypes.ContainsKey("Parameter")); + } + +} \ No newline at end of file diff --git a/drivers/spanner-ado-net/spanner-ado-net-specification-tests/README.md b/drivers/spanner-ado-net/spanner-ado-net-specification-tests/README.md new file mode 100644 index 00000000..9bbfd808 --- /dev/null +++ b/drivers/spanner-ado-net/spanner-ado-net-specification-tests/README.md @@ -0,0 +1,5 @@ +# Spanner ADO.NET Data Provider Specification Tests + +Specification tests for ADO.NET Data Provider for Spanner. + +__ALPHA: Not for production use__ diff --git a/drivers/spanner-ado-net/spanner-ado-net-specification-tests/TransactionTests.cs b/drivers/spanner-ado-net/spanner-ado-net-specification-tests/TransactionTests.cs new file mode 100644 index 00000000..faf8b719 --- /dev/null +++ b/drivers/spanner-ado-net/spanner-ado-net-specification-tests/TransactionTests.cs @@ -0,0 +1,19 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +using AdoNet.Specification.Tests; + +namespace Google.Cloud.Spanner.DataProvider.SpecificationTests; + +public class TransactionTests(DbFactoryFixture fixture) : TransactionTestBase(fixture); \ No newline at end of file diff --git a/drivers/spanner-ado-net/spanner-ado-net-specification-tests/appsettings.json b/drivers/spanner-ado-net/spanner-ado-net-specification-tests/appsettings.json new file mode 100644 index 00000000..2a8537d3 --- /dev/null +++ b/drivers/spanner-ado-net/spanner-ado-net-specification-tests/appsettings.json @@ -0,0 +1,8 @@ +{ + "Logging": { + "IncludeScopes": false, + "LogLevel": { + "Microsoft": "Warning" + } + } +} diff --git a/drivers/spanner-ado-net/spanner-ado-net-specification-tests/spanner-ado-net-specification-tests.csproj b/drivers/spanner-ado-net/spanner-ado-net-specification-tests/spanner-ado-net-specification-tests.csproj new file mode 100644 index 00000000..905694fb --- /dev/null +++ b/drivers/spanner-ado-net/spanner-ado-net-specification-tests/spanner-ado-net-specification-tests.csproj @@ -0,0 +1,34 @@ + + + + net8.0 + Google.Cloud.Spanner.DataProvider.SpecificationTests + enable + enable + + false + true + Google.Cloud.Spanner.DataProvider.SpecificationTests + default + + + + + + all + runtime; build; native; contentfiles; analyzers; buildtransitive + + + + + + + + + + + + + + + diff --git a/drivers/spanner-ado-net/spanner-ado-net-tests/AbstractMockServerTests.cs b/drivers/spanner-ado-net/spanner-ado-net-tests/AbstractMockServerTests.cs new file mode 100644 index 00000000..e5913b4a --- /dev/null +++ b/drivers/spanner-ado-net/spanner-ado-net-tests/AbstractMockServerTests.cs @@ -0,0 +1,146 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +using System.Data.Common; +using Google.Cloud.SpannerLib.MockServer; + +namespace Google.Cloud.Spanner.DataProvider.Tests; + +public abstract class AbstractMockServerTests +{ + static AbstractMockServerTests() + { + AppDomain.CurrentDomain.ProcessExit += (_, _) => + { + SpannerPool.CloseSpannerLib(); + }; + } + + protected SpannerMockServerFixture Fixture; + + protected SpannerDataSource DataSource { get; private set; } + + + protected string ConnectionString => $"Host={Fixture.Host};Port={Fixture.Port};Data Source=projects/p1/instances/i1/databases/d1;UsePlainText=true"; + + [OneTimeSetUp] + public void Setup() + { + Fixture = new SpannerMockServerFixture(); + DataSource = SpannerDataSource.Create(ConnectionString); + } + + [OneTimeTearDown] + public void Teardown() + { + DataSource.Dispose(); + Fixture.Dispose(); + } + + [SetUp] + public void SetupResults() + { + Fixture.SpannerMock.AddOrUpdateStatementResult("SELECT 1", StatementResult.CreateSelect1ResultSet()); + } + + [TearDown] + public void Reset() + { + Fixture.SpannerMock.Reset(); + } + + protected SpannerConnection OpenConnection() + { + var connection = new SpannerConnection(ConnectionString); + connection.Open(); + return connection; + } + + protected async Task OpenConnectionAsync() + { + var connection = new SpannerConnection(ConnectionString); + await connection.OpenAsync(); + return connection; + } + + protected SpannerDataSource CreateDataSource() + { + return CreateDataSource(_ => { }); + } + + protected SpannerDataSource CreateDataSource(string connectionString) + { + return CreateDataSource(csb => { csb.ConnectionString = connectionString; }); + } + + protected SpannerDataSource CreateDataSource(Action connectionStringBuilderAction) + { + var connectionStringBuilder = new SpannerConnectionStringBuilder(ConnectionString); + connectionStringBuilderAction(connectionStringBuilder); + return SpannerDataSource.Create(connectionStringBuilder); + } + +} + +public static class SpannerConnectionExtensions +{ + public static int ExecuteNonQuery(this SpannerConnection conn, string sql, SpannerTransaction? tx = null) + { + using var command = tx == null ? new SpannerCommand(sql, conn) : new SpannerCommand(sql, conn, tx); + return command.ExecuteNonQuery(); + } + + public static object? ExecuteScalar(this SpannerConnection conn, string sql, SpannerTransaction? tx = null) + { + using var command = tx == null ? new SpannerCommand(sql, conn) : new SpannerCommand(sql, conn, tx); + return command.ExecuteScalar(); + } + + public static async Task ExecuteNonQueryAsync( + this SpannerConnection conn, string sql, SpannerTransaction? tx = null, CancellationToken cancellationToken = default) + { + await using var command = tx == null ? new SpannerCommand(sql, conn) : new SpannerCommand(sql, conn, tx); + return await command.ExecuteNonQueryAsync(cancellationToken); + } + + public static async Task ExecuteScalarAsync( + this SpannerConnection conn, string sql, SpannerTransaction? tx = null, CancellationToken cancellationToken = default) + { + await using var command = tx == null ? new SpannerCommand(sql, conn) : new SpannerCommand(sql, conn, tx); + return await command.ExecuteScalarAsync(cancellationToken); + } +} + +public static class SpannerCommandExtensions +{ + internal static void AddParameter(this SpannerCommand command, string name, object? value) + { + var parameter = command.CreateParameter(); + parameter.ParameterName = name; + parameter.Value = value; + command.Parameters.Add(parameter); + } +} + +public static class BatchExtensions +{ + internal static void AddSpannerBatchCommand(this DbBatch batch, string sql) + { + var command = new SpannerBatchCommand + { + CommandText = sql + }; + batch.BatchCommands.Add(command); + } +} \ No newline at end of file diff --git a/drivers/spanner-ado-net/spanner-ado-net-tests/BasicTests.cs b/drivers/spanner-ado-net/spanner-ado-net-tests/BasicTests.cs new file mode 100644 index 00000000..6962038a --- /dev/null +++ b/drivers/spanner-ado-net/spanner-ado-net-tests/BasicTests.cs @@ -0,0 +1,123 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +using System.Data.Common; +using System.Text.Json; +using Google.Cloud.Spanner.V1; +using Google.Cloud.SpannerLib.MockServer; + +namespace Google.Cloud.Spanner.DataProvider.Tests; + +public class BasicTests : AbstractMockServerTests +{ + [Test] + public void TestOpenConnection() + { + var connection = new SpannerConnection(); + connection.ConnectionString = ConnectionString; + connection.Open(); + connection.Close(); + } + + [Test] + public void TestExecuteQuery() + { + using var connection = new SpannerConnection(); + connection.ConnectionString = ConnectionString; + connection.Open(); + + using var cmd = connection.CreateCommand(); + cmd.CommandText = "SELECT 1"; + using var reader = cmd.ExecuteReader(); + while (reader.Read()) + { + Assert.That(reader.GetInt64(0), Is.EqualTo(1)); + } + } + + [Test] + public void TestExecuteParameterizedQuery() + { + Fixture.SpannerMock.AddOrUpdateStatementResult("SELECT $1", StatementResult.CreateSelect1ResultSet()); + using var connection = new SpannerConnection(); + connection.ConnectionString = ConnectionString; + connection.Open(); + + using var cmd = connection.CreateCommand(); + cmd.CommandText = "SELECT $1"; + var param = cmd.CreateParameter(); + param.ParameterName = "p1"; + param.Value = 1; + cmd.Parameters.Add(param); + using var reader = cmd.ExecuteReader(); + while (reader.Read()) + { + Assert.That(reader.GetInt64(0), Is.EqualTo(1)); + } + } + + [Test] + public void TestInsertAllDataTypes() + { + var sql = "insert into all_types (col_bool, col_bytes, col_date, col_interval, col_json, col_int64, col_float32, col_float64, col_numeric, col_string, col_timestamp) " + + "values (@col_bool, @col_bytes, @col_date, @col_interval, @col_json, @col_int64, @col_float32, @col_float64, @col_numeric, @col_string, @col_timestamp)"; + Fixture.SpannerMock.AddOrUpdateStatementResult(sql, StatementResult.CreateUpdateCount(1)); + + using var connection = new SpannerConnection(); + connection.ConnectionString = ConnectionString; + connection.Open(); + + using var cmd = connection.CreateCommand(); + cmd.CommandText = sql; + AddParameter(cmd, "col_bool", true); + AddParameter(cmd, "col_bytes", new byte[] { 1, 2, 3 }); + AddParameter(cmd, "col_date", new DateOnly(2025, 8, 25)); + AddParameter(cmd, "col_interval", TimeSpan.FromHours(1)); + AddParameter(cmd, "col_json", JsonDocument.Parse("{\"key\":\"value\"}")); + AddParameter(cmd, "col_int64", 10); + AddParameter(cmd, "col_float32", 3.14f); + AddParameter(cmd, "col_float64", 3.14d); + AddParameter(cmd, "col_numeric", 10.1m); + AddParameter(cmd, "col_string", "hello"); + AddParameter(cmd, "col_timestamp", DateTime.Parse("2025-08-25T16:30:55Z")); + + var updateCount = cmd.ExecuteNonQuery(); + Assert.That(updateCount, Is.EqualTo(1)); + + var requests = Fixture.SpannerMock.Requests.OfType().ToList(); + Assert.That(requests, Has.Count.EqualTo(1)); + var request = requests.First(); + Assert.That(request.Params.Fields, Has.Count.EqualTo(11)); + Assert.That(request.Params.Fields["col_bool"].BoolValue, Is.EqualTo(true)); + Assert.That(request.Params.Fields["col_bytes"].StringValue, Is.EqualTo(Convert.ToBase64String(new byte[]{1,2,3}))); + Assert.That(request.Params.Fields["col_date"].StringValue, Is.EqualTo("2025-08-25")); + Assert.That(request.Params.Fields["col_interval"].StringValue, Is.EqualTo("PT1H")); + Assert.That(request.Params.Fields["col_int64"].StringValue, Is.EqualTo("10")); + Assert.That(request.Params.Fields["col_float32"].NumberValue, Is.EqualTo(3.14f)); + Assert.That(request.Params.Fields["col_float64"].NumberValue, Is.EqualTo(3.14d)); + Assert.That(request.Params.Fields["col_numeric"].StringValue, Is.EqualTo("10.1")); + Assert.That(request.Params.Fields["col_string"].StringValue, Is.EqualTo("hello")); + Assert.That(request.Params.Fields["col_timestamp"].StringValue, Is.EqualTo("2025-08-25T16:30:55.0000000Z")); + + Assert.That(request.ParamTypes.Count, Is.EqualTo(0)); + } + + private void AddParameter(DbCommand command, string name, object value) + { + var param = command.CreateParameter(); + param.ParameterName = name; + param.Value = value; + command.Parameters.Add(param); + } +} diff --git a/drivers/spanner-ado-net/spanner-ado-net-tests/BatchTests.cs b/drivers/spanner-ado-net/spanner-ado-net-tests/BatchTests.cs new file mode 100644 index 00000000..6590020c --- /dev/null +++ b/drivers/spanner-ado-net/spanner-ado-net-tests/BatchTests.cs @@ -0,0 +1,272 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +using System.Data.Common; +using System.Text.Json; +using Google.Cloud.Spanner.V1; +using Google.Cloud.SpannerLib.MockServer; +using Google.Protobuf.WellKnownTypes; + +namespace Google.Cloud.Spanner.DataProvider.Tests; + +public class BatchTests : AbstractMockServerTests +{ + [TestCase(1, false)] + [TestCase(2, false)] + [TestCase(5, false)] + [TestCase(1, true)] + [TestCase(2, true)] + [TestCase(5, true)] + public async Task TestAllParameterTypes(int numCommands, bool executeAsync) + { + await using var connection = new SpannerConnection(); + connection.ConnectionString = ConnectionString; + await connection.OpenAsync(); + + const string insert = "insert into my_table values (@p1, @p2, @p3, @p4, @p5, @p6, @p7, @p8, @p9, @p10, @p11, @p12)"; + Fixture.SpannerMock.AddOrUpdateStatementResult(insert, StatementResult.CreateUpdateCount(1)); + + await using var batch = connection.CreateBatch(); + + for (var i = 0; i < numCommands; i++) + { + var command = batch.CreateBatchCommand(); + command.CommandText = insert; + // TODO: + // - PROTO + // - STRUCT + AddParameter(command, "p1", true); + AddParameter(command, "p2", new byte[] { 1, 2, 3 }); + AddParameter(command, "p3", new DateOnly(2025, 10, 2)); + AddParameter(command, "p4", new TimeSpan(1, 2, 3, 4, 5, 6)); + AddParameter(command, "p5", JsonDocument.Parse("{\"key\": \"value\"}")); + AddParameter(command, "p6", 9.99m); + AddParameter(command, "p7", "test"); + AddParameter(command, "p8", new DateTime(2025, 10, 2, 15, 57, 31, 999, DateTimeKind.Utc)); + AddParameter(command, "p9", Guid.Parse("5555990c-b259-4539-bd22-5a9293cf10ac")); + AddParameter(command, "p10", 3.14d); + AddParameter(command, "p11", 3.14f); + AddParameter(command, "p12", DBNull.Value); + + AddParameter(command, "p13", new bool?[] { true, false, null }); + AddParameter(command, "p14", new byte[]?[] { [1, 2, 3], null }); + AddParameter(command, "p15", new DateOnly?[] { new DateOnly(2025, 10, 2), null }); + AddParameter(command, "p16", new TimeSpan?[] { new TimeSpan(1, 2, 3, 4, 5, 6), null }); + AddParameter(command, "p17", new[] { JsonDocument.Parse("{\"key\": \"value\"}"), null }); + AddParameter(command, "p18", new decimal?[] { 9.99m, null }); + AddParameter(command, "p19", new[] { "test", null }); + AddParameter(command, "p20", + new DateTime?[] { new DateTime(2025, 10, 2, 15, 57, 31, 999, DateTimeKind.Utc), null }); + AddParameter(command, "p21", new Guid?[] { Guid.Parse("5555990c-b259-4539-bd22-5a9293cf10ac"), null }); + AddParameter(command, "p22", new double?[] { 3.14d, null }); + AddParameter(command, "p23", new float?[] { 3.14f, null }); + + batch.BatchCommands.Add(command); + } + + int affected; + if (executeAsync) + { + affected = await batch.ExecuteNonQueryAsync(); + } + else + { + // ReSharper disable once MethodHasAsyncOverload + affected = batch.ExecuteNonQuery(); + } + Assert.That(affected, Is.EqualTo(numCommands)); + foreach (var command in batch.BatchCommands) + { + Assert.That(command.RecordsAffected, Is.EqualTo(1)); + } + + var requests = Fixture.SpannerMock.Requests.ToList(); + Assert.That(requests.OfType().Count, Is.EqualTo(1)); + Assert.That(requests.OfType().Count, Is.EqualTo(1)); + var request = requests.OfType().Single(); + Assert.That(request.Statements.Count, Is.EqualTo(numCommands)); + foreach (var statement in request.Statements) + { + // The driver does not send any parameter types, unless it is explicitly asked to do so. + Assert.That(statement.ParamTypes.Count, Is.EqualTo(0)); + Assert.That(statement.Params.Fields.Count, Is.EqualTo(23)); + var fields = statement.Params.Fields; + Assert.That(fields["p1"].HasBoolValue, Is.True); + Assert.That(fields["p1"].BoolValue, Is.True); + Assert.That(fields["p2"].HasStringValue, Is.True); + Assert.That(fields["p2"].StringValue, Is.EqualTo(Convert.ToBase64String(new byte[] { 1, 2, 3 }))); + Assert.That(fields["p3"].HasStringValue, Is.True); + Assert.That(fields["p3"].StringValue, Is.EqualTo("2025-10-02")); + Assert.That(fields["p4"].HasStringValue, Is.True); + Assert.That(fields["p4"].StringValue, Is.EqualTo("P1DT2H3M4.005006S")); + Assert.That(fields["p5"].HasStringValue, Is.True); + Assert.That(fields["p5"].StringValue, Is.EqualTo("{\"key\": \"value\"}")); + Assert.That(fields["p6"].HasStringValue, Is.True); + Assert.That(fields["p6"].StringValue, Is.EqualTo("9.99")); + Assert.That(fields["p7"].HasStringValue, Is.True); + Assert.That(fields["p7"].StringValue, Is.EqualTo("test")); + Assert.That(fields["p8"].HasStringValue, Is.True); + Assert.That(fields["p8"].StringValue, Is.EqualTo("2025-10-02T15:57:31.9990000Z")); + Assert.That(fields["p9"].HasStringValue, Is.True); + Assert.That(fields["p9"].StringValue, Is.EqualTo("5555990c-b259-4539-bd22-5a9293cf10ac")); + Assert.That(fields["p10"].HasNumberValue, Is.True); + Assert.That(fields["p10"].NumberValue, Is.EqualTo(3.14d)); + Assert.That(fields["p11"].HasNumberValue, Is.True); + Assert.That(fields["p11"].NumberValue, Is.EqualTo(3.14f)); + Assert.That(fields["p12"].HasNullValue, Is.True); + + Assert.That(fields["p13"].KindCase, Is.EqualTo(Value.KindOneofCase.ListValue)); + Assert.That(fields["p13"].ListValue.Values.Count, Is.EqualTo(3)); + Assert.That(fields["p13"].ListValue.Values[0].HasBoolValue, Is.True); + Assert.That(fields["p13"].ListValue.Values[0].BoolValue, Is.True); + Assert.That(fields["p13"].ListValue.Values[1].HasBoolValue, Is.True); + Assert.That(fields["p13"].ListValue.Values[1].BoolValue, Is.False); + Assert.That(fields["p13"].ListValue.Values[2].HasNullValue, Is.True); + + Assert.That(fields["p14"].KindCase, Is.EqualTo(Value.KindOneofCase.ListValue)); + Assert.That(fields["p14"].ListValue.Values.Count, Is.EqualTo(2)); + Assert.That(fields["p14"].ListValue.Values[0].HasStringValue, Is.True); + Assert.That(fields["p14"].ListValue.Values[0].StringValue, + Is.EqualTo(Convert.ToBase64String(new byte[] { 1, 2, 3 }))); + Assert.That(fields["p14"].ListValue.Values[1].HasNullValue, Is.True); + + Assert.That(fields["p15"].KindCase, Is.EqualTo(Value.KindOneofCase.ListValue)); + Assert.That(fields["p15"].ListValue.Values.Count, Is.EqualTo(2)); + Assert.That(fields["p15"].ListValue.Values[0].HasStringValue, Is.True); + Assert.That(fields["p15"].ListValue.Values[0].StringValue, Is.EqualTo("2025-10-02")); + Assert.That(fields["p15"].ListValue.Values[1].HasNullValue, Is.True); + + Assert.That(fields["p16"].KindCase, Is.EqualTo(Value.KindOneofCase.ListValue)); + Assert.That(fields["p16"].ListValue.Values.Count, Is.EqualTo(2)); + Assert.That(fields["p16"].ListValue.Values[0].HasStringValue, Is.True); + Assert.That(fields["p16"].ListValue.Values[0].StringValue, Is.EqualTo("P1DT2H3M4.005006S")); + Assert.That(fields["p16"].ListValue.Values[1].HasNullValue, Is.True); + + Assert.That(fields["p17"].KindCase, Is.EqualTo(Value.KindOneofCase.ListValue)); + Assert.That(fields["p17"].ListValue.Values.Count, Is.EqualTo(2)); + Assert.That(fields["p17"].ListValue.Values[0].HasStringValue, Is.True); + Assert.That(fields["p17"].ListValue.Values[0].StringValue, Is.EqualTo("{\"key\": \"value\"}")); + Assert.That(fields["p17"].ListValue.Values[1].HasNullValue, Is.True); + + Assert.That(fields["p18"].KindCase, Is.EqualTo(Value.KindOneofCase.ListValue)); + Assert.That(fields["p18"].ListValue.Values.Count, Is.EqualTo(2)); + Assert.That(fields["p18"].ListValue.Values[0].HasStringValue, Is.True); + Assert.That(fields["p18"].ListValue.Values[0].StringValue, Is.EqualTo("9.99")); + Assert.That(fields["p18"].ListValue.Values[1].HasNullValue, Is.True); + + Assert.That(fields["p19"].KindCase, Is.EqualTo(Value.KindOneofCase.ListValue)); + Assert.That(fields["p19"].ListValue.Values.Count, Is.EqualTo(2)); + Assert.That(fields["p19"].ListValue.Values[0].HasStringValue, Is.True); + Assert.That(fields["p19"].ListValue.Values[0].StringValue, Is.EqualTo("test")); + Assert.That(fields["p19"].ListValue.Values[1].HasNullValue, Is.True); + + Assert.That(fields["p20"].KindCase, Is.EqualTo(Value.KindOneofCase.ListValue)); + Assert.That(fields["p20"].ListValue.Values.Count, Is.EqualTo(2)); + Assert.That(fields["p20"].ListValue.Values[0].HasStringValue, Is.True); + Assert.That(fields["p20"].ListValue.Values[0].StringValue, Is.EqualTo("2025-10-02T15:57:31.9990000Z")); + Assert.That(fields["p20"].ListValue.Values[1].HasNullValue, Is.True); + + Assert.That(fields["p21"].KindCase, Is.EqualTo(Value.KindOneofCase.ListValue)); + Assert.That(fields["p21"].ListValue.Values.Count, Is.EqualTo(2)); + Assert.That(fields["p21"].ListValue.Values[0].HasStringValue, Is.True); + Assert.That(fields["p21"].ListValue.Values[0].StringValue, + Is.EqualTo("5555990c-b259-4539-bd22-5a9293cf10ac")); + Assert.That(fields["p21"].ListValue.Values[1].HasNullValue, Is.True); + + Assert.That(fields["p22"].KindCase, Is.EqualTo(Value.KindOneofCase.ListValue)); + Assert.That(fields["p22"].ListValue.Values.Count, Is.EqualTo(2)); + Assert.That(fields["p22"].ListValue.Values[0].HasNumberValue, Is.True); + Assert.That(fields["p22"].ListValue.Values[0].NumberValue, Is.EqualTo(3.14d)); + Assert.That(fields["p22"].ListValue.Values[1].HasNullValue, Is.True); + + Assert.That(fields["p23"].KindCase, Is.EqualTo(Value.KindOneofCase.ListValue)); + Assert.That(fields["p23"].ListValue.Values.Count, Is.EqualTo(2)); + Assert.That(fields["p23"].ListValue.Values[0].HasNumberValue, Is.True); + Assert.That(fields["p23"].ListValue.Values[0].NumberValue, Is.EqualTo(3.14f)); + Assert.That(fields["p23"].ListValue.Values[1].HasNullValue, Is.True); + } + } + + [TestCase(true)] + [TestCase(false)] + public async Task TestEmptyBatch(bool executeAsync) + { + await using var connection = new SpannerConnection(); + connection.ConnectionString = ConnectionString; + await connection.OpenAsync(); + + await using var batch = connection.CreateBatch(); + int affected; + if (executeAsync) + { + affected = await batch.ExecuteNonQueryAsync(); + } + else + { + // ReSharper disable once MethodHasAsyncOverload + affected = batch.ExecuteNonQuery(); + } + Assert.That(affected, Is.EqualTo(0)); + } + + [TestCase(true)] + [TestCase(false)] + public async Task TestExecuteReader(bool executeAsync) + { + await using var connection = new SpannerConnection(); + connection.ConnectionString = ConnectionString; + await connection.OpenAsync(); + + await using var batch = connection.CreateBatch(); + var command = batch.CreateBatchCommand(); + command.CommandText = "select * from my_table"; + if (executeAsync) + { + Assert.ThrowsAsync(() => batch.ExecuteReaderAsync()); + } + else + { + Assert.Throws(() => batch.ExecuteReader()); + } + } + + [TestCase(true)] + [TestCase(false)] + public async Task TestExecuteScalar(bool executeAsync) + { + await using var connection = new SpannerConnection(); + connection.ConnectionString = ConnectionString; + await connection.OpenAsync(); + + await using var batch = connection.CreateBatch(); + var command = batch.CreateBatchCommand(); + command.CommandText = "select * from my_table"; + if (executeAsync) + { + Assert.ThrowsAsync(() => batch.ExecuteScalarAsync()); + } + else + { + Assert.Throws(() => batch.ExecuteScalar()); + } + } + + private static void AddParameter(DbBatchCommand command, string name, object? value) + { + var parameter = command.CreateParameter(); + parameter.ParameterName = name; + parameter.Value = value; + command.Parameters.Add(parameter); + } +} \ No newline at end of file diff --git a/drivers/spanner-ado-net/spanner-ado-net-tests/CommandParameterTests.cs b/drivers/spanner-ado-net/spanner-ado-net-tests/CommandParameterTests.cs new file mode 100644 index 00000000..d1ec8775 --- /dev/null +++ b/drivers/spanner-ado-net/spanner-ado-net-tests/CommandParameterTests.cs @@ -0,0 +1,218 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +using System.Data; +using Google.Cloud.Spanner.Admin.Database.V1; +using Google.Cloud.Spanner.V1; +using Google.Cloud.SpannerLib.MockServer; +using TypeCode = Google.Cloud.Spanner.V1.TypeCode; + +namespace Google.Cloud.Spanner.DataProvider.Tests; + +public class CommandParameterTests : AbstractMockServerTests +{ + [Test] + [TestCase(CommandBehavior.Default)] + [TestCase(CommandBehavior.SequentialAccess)] + public async Task InputAndOutputParameters(CommandBehavior behavior) + { + const string sql = "SELECT @c-1 AS c, @a+2 AS b"; + Fixture.SpannerMock.AddOrUpdateStatementResult(sql, StatementResult.CreateResultSet( + new List>([ + Tuple.Create(TypeCode.Int64, "c"), + Tuple.Create(TypeCode.Int64, "b"), + ]), + new List([[3, 5]]))); + + await using var conn = await OpenConnectionAsync(); + await using var cmd = new SpannerCommand(sql, conn); + cmd.AddParameter("a", 3); + var b = new SpannerParameter { ParameterName = "b", Direction = ParameterDirection.Output }; + cmd.Parameters.Add(b); + var c = new SpannerParameter { ParameterName = "c", Direction = ParameterDirection.InputOutput, Value = 4 }; + cmd.Parameters.Add(c); + await using (await cmd.ExecuteReaderAsync(behavior)) + { + // TODO: Enable if we decide to support output parameters in the same way as npgsql. + // Assert.That(b.Value, Is.EqualTo(5)); + // Assert.That(c.Value, Is.EqualTo(3)); + } + var request = Fixture.SpannerMock.Requests.Single(r => r is ExecuteSqlRequest { Sql: sql }) as ExecuteSqlRequest; + Assert.That(request, Is.Not.Null); + Assert.That(request.Params.Fields.Count, Is.EqualTo(3)); + Assert.That(request.Params.Fields["a"].StringValue, Is.EqualTo("3")); + Assert.That(request.Params.Fields["b"].HasNullValue); + Assert.That(request.Params.Fields["c"].StringValue, Is.EqualTo("4")); + } + + [Test] + public async Task SendWithoutType([Values(PrepareOrNot.NotPrepared, PrepareOrNot.Prepared)] PrepareOrNot prepare) + { + const string sql = "select cast(@p as timestamp)"; + Fixture.SpannerMock.AddOrUpdateStatementResult(sql, StatementResult.CreateSingleColumnResultSet( + new V1.Type{Code = TypeCode.Timestamp}, "p", "2025-10-30T10:00:00.000000000Z")); + + await using var conn = await OpenConnectionAsync(); + await using var cmd = new SpannerCommand(sql, conn); + cmd.AddParameter("p", "2025-10-30T10:00:00Z"); + if (prepare == PrepareOrNot.Prepared) + { + await cmd.PrepareAsync(); + } + + await using var reader = await cmd.ExecuteReaderAsync(); + await reader.ReadAsync(); + Assert.That(reader.GetValue(0), Is.EqualTo(new DateTime(2025, 10, 30, 10, 0, 0, DateTimeKind.Utc))); + + var request = Fixture.SpannerMock.Requests.First(r => r is ExecuteSqlRequest { Sql: sql }) as ExecuteSqlRequest; + Assert.That(request, Is.Not.Null); + Assert.That(request.Params.Fields.Count, Is.EqualTo(1)); + Assert.That(request.Params.Fields["p"].StringValue, Is.EqualTo("2025-10-30T10:00:00Z")); + Assert.That(request.ParamTypes.Count, Is.EqualTo(0)); + + var expectedCount = prepare == PrepareOrNot.Prepared ? 2 : 1; + Assert.That(Fixture.SpannerMock.Requests.Count(r => r is ExecuteSqlRequest { Sql: sql }), Is.EqualTo(expectedCount)); + } + + [Test] + public async Task PositionalParameter() + { + // Set the database dialect to PostgreSQL to enable the use of PostgreSQL-style positional parameters. + Fixture.SpannerMock.AddDialectResult(DatabaseDialect.Postgresql); + const string sql = "SELECT $1"; + Fixture.SpannerMock.AddOrUpdateStatementResult(sql, StatementResult.CreateSingleColumnResultSet( + new V1.Type{Code = TypeCode.Int64}, "c", 8L)); + + await using var conn = await OpenConnectionAsync(); + await using var cmd = new SpannerCommand(sql, conn); + cmd.Parameters.Add(new SpannerParameter { Value = 8 }); + Assert.That(await cmd.ExecuteScalarAsync(), Is.EqualTo(8)); + + var request = Fixture.SpannerMock.Requests.Single(r => r is ExecuteSqlRequest { Sql: sql }) as ExecuteSqlRequest; + Assert.That(request, Is.Not.Null); + Assert.That(request.Params.Fields.Count, Is.EqualTo(1)); + Assert.That(request.Params.Fields["p1"].StringValue, Is.EqualTo("8")); + Assert.That(request.ParamTypes.Count, Is.EqualTo(0)); + } + + [Test] + public async Task UnreferencedNamedParameterIsIgnored() + { + await using var conn = await OpenConnectionAsync(); + await using var cmd = new SpannerCommand("SELECT 1", conn); + cmd.AddParameter("not_used", 8); + Assert.That(await cmd.ExecuteScalarAsync(), Is.EqualTo(1)); + + var request = Fixture.SpannerMock.Requests.Single(r => r is ExecuteSqlRequest { Sql: "SELECT 1" }) as ExecuteSqlRequest; + Assert.That(request, Is.Not.Null); + Assert.That(request.Params.Fields.Count, Is.EqualTo(1)); + Assert.That(request.Params.Fields["not_used"].StringValue, Is.EqualTo("8")); + Assert.That(request.ParamTypes.Count, Is.EqualTo(0)); + } + + [Test] + public async Task UnreferencedPositionalParameterIsIgnored() + { + // Set the database dialect to PostgreSQL to enable the use of PostgreSQL-style positional parameters. + Fixture.SpannerMock.AddDialectResult(DatabaseDialect.Postgresql); + await using var conn = await OpenConnectionAsync(); + await using var cmd = new SpannerCommand("SELECT 1", conn); + cmd.Parameters.Add(new SpannerParameter { Value = 8 }); + Assert.That(await cmd.ExecuteScalarAsync(), Is.EqualTo(1)); + + var request = Fixture.SpannerMock.Requests.Single(r => r is ExecuteSqlRequest { Sql: "SELECT 1" }) as ExecuteSqlRequest; + Assert.That(request, Is.Not.Null); + Assert.That(request.Params.Fields.Count, Is.EqualTo(1)); + Assert.That(request.Params.Fields["p1"].StringValue, Is.EqualTo("8")); + Assert.That(request.ParamTypes.Count, Is.EqualTo(0)); + } + + [Test] + public void ParameterName() + { + var command = new SpannerCommand(); + + // Add parameters. + command.Parameters.Add(new SpannerParameter{ ParameterName = "@Parameter1", DbType = DbType.Boolean, Value = true }); + command.Parameters.Add(new SpannerParameter{ ParameterName = "@Parameter2", DbType = DbType.Int32, Value = 1 }); + command.Parameters.Add(new SpannerParameter{ ParameterName = "Parameter3", DbType = DbType.DateTime, Value = DBNull.Value }); + command.Parameters.Add(new SpannerParameter{ ParameterName = "Parameter4", DbType = DbType.Binary, Value = DBNull.Value }); + + var parameter = command.Parameters["@Parameter1"]; + Assert.That(parameter, Is.Not.Null); + command.Parameters[0].Value = 1; + + Assert.That(command.Parameters["@Parameter1"].ParameterName, Is.EqualTo("@Parameter1")); + Assert.That(command.Parameters["@Parameter2"].ParameterName, Is.EqualTo("@Parameter2")); + Assert.That(command.Parameters["Parameter3"].ParameterName, Is.EqualTo("Parameter3")); + Assert.That(command.Parameters["Parameter4"].ParameterName, Is.EqualTo("Parameter4")); + + Assert.That(command.Parameters[0].ParameterName, Is.EqualTo("@Parameter1")); + Assert.That(command.Parameters[1].ParameterName, Is.EqualTo("@Parameter2")); + Assert.That(command.Parameters[2].ParameterName, Is.EqualTo("Parameter3")); + Assert.That(command.Parameters[3].ParameterName, Is.EqualTo("Parameter4")); + + // Verify that the '@' is stripped before being sent to Spanner. + var statement = command.BuildStatement(); + Assert.That(statement, Is.Not.Null); + Assert.That(statement.Params.Fields.Count, Is.EqualTo(4)); + Assert.That(statement.Params.Fields["Parameter1"].StringValue, Is.EqualTo("1")); + Assert.That(statement.Params.Fields["Parameter2"].StringValue, Is.EqualTo("1")); + Assert.That(statement.Params.Fields["Parameter3"].HasNullValue); + Assert.That(statement.Params.Fields["Parameter4"].HasNullValue); + + Assert.That(statement.ParamTypes.Count, Is.EqualTo(4)); + Assert.That(statement.ParamTypes["Parameter1"].Code, Is.EqualTo(TypeCode.Bool)); + Assert.That(statement.ParamTypes["Parameter2"].Code, Is.EqualTo(TypeCode.Int64)); + Assert.That(statement.ParamTypes["Parameter3"].Code, Is.EqualTo(TypeCode.Timestamp)); + Assert.That(statement.ParamTypes["Parameter4"].Code, Is.EqualTo(TypeCode.Bytes)); + } + + [Test] + public async Task SameParamMultipleTimes() + { + const string sql = "SELECT @p1, @p1"; + Fixture.SpannerMock.AddOrUpdateStatementResult(sql, StatementResult.CreateResultSet( + new List>([Tuple.Create(TypeCode.Int64, "p1"), Tuple.Create(TypeCode.Int64, "p1")]), + new List([[8, 8]]))); + + await using var conn = await OpenConnectionAsync(); + await using var cmd = new SpannerCommand(sql, conn); + cmd.AddParameter("@p1", 8); + await using var reader = await cmd.ExecuteReaderAsync(); + await reader.ReadAsync(); + Assert.That(reader[0], Is.EqualTo(8)); + Assert.That(reader[1], Is.EqualTo(8)); + + var request = Fixture.SpannerMock.Requests.Single(r => r is ExecuteSqlRequest { Sql: sql }) as ExecuteSqlRequest; + Assert.That(request, Is.Not.Null); + Assert.That(request.Params.Fields.Count, Is.EqualTo(1)); + Assert.That(request.Params.Fields["p1"].StringValue, Is.EqualTo("8")); + Assert.That(request.ParamTypes.Count, Is.EqualTo(0)); + } + + [Test] + public async Task ParameterMustBeSet() + { + await using var conn = await OpenConnectionAsync(); + await using var cmd = new SpannerCommand("SELECT @p1::TEXT", conn); + cmd.Parameters.Add(new SpannerParameter{ ParameterName = "@p1" }); + + Assert.That(async () => await cmd.ExecuteReaderAsync(), + Throws.Exception + .TypeOf() + .With.Message.EqualTo("Parameter @p1 has no value")); + } + +} \ No newline at end of file diff --git a/drivers/spanner-ado-net/spanner-ado-net-tests/CommandTests.cs b/drivers/spanner-ado-net/spanner-ado-net-tests/CommandTests.cs new file mode 100644 index 00000000..7bee948b --- /dev/null +++ b/drivers/spanner-ado-net/spanner-ado-net-tests/CommandTests.cs @@ -0,0 +1,890 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +using System.Data; +using System.Data.Common; +using System.Diagnostics.CodeAnalysis; +using System.Text; +using System.Text.Json; +using Google.Cloud.Spanner.V1; +using Google.Cloud.SpannerLib.MockServer; +using Google.Protobuf.WellKnownTypes; +using Grpc.Core; +using Status = Grpc.Core.Status; +using TypeCode = Google.Cloud.Spanner.V1.TypeCode; + +namespace Google.Cloud.Spanner.DataProvider.Tests; + +public class CommandTests : AbstractMockServerTests +{ + [Test] + public async Task TestAllParameterTypes() + { + var insert = "insert into my_table values (@p1, @p2, @p3, @p4, @p5, @p6, @p7, @p8, @p9, @p10, @p11, @p12)"; + Fixture.SpannerMock.AddOrUpdateStatementResult(insert, StatementResult.CreateUpdateCount(1)); + + await using var connection = new SpannerConnection(); + connection.ConnectionString = ConnectionString; + await connection.OpenAsync(); + + await using var command = connection.CreateCommand(); + command.CommandText = insert; + // TODO: + // - PROTO + // - STRUCT + AddParameter(command, "p1", true); + AddParameter(command, "p2", new byte[] {1, 2, 3}); + AddParameter(command, "p3", new DateOnly(2025, 10, 2)); + AddParameter(command, "p4", new TimeSpan(1, 2, 3, 4, 5, 6)); + AddParameter(command, "p5", JsonDocument.Parse("{\"key\": \"value\"}")); + AddParameter(command, "p6", 9.99m); + AddParameter(command, "p7", "test"); + AddParameter(command, "p8", new DateTime(2025, 10, 2, 15, 57, 31, 999, DateTimeKind.Utc)); + AddParameter(command, "p9", Guid.Parse("5555990c-b259-4539-bd22-5a9293cf10ac")); + AddParameter(command, "p10", 3.14d); + AddParameter(command, "p11", 3.14f); + AddParameter(command, "p12", DBNull.Value); + + AddParameter(command, "p13", new bool?[]{true, false, null}); + AddParameter(command, "p14", new byte[]?[]{ [1,2,3], null }); + AddParameter(command, "p15", new DateOnly?[] { new DateOnly(2025, 10, 2), null }); + AddParameter(command, "p16", new TimeSpan?[] { new TimeSpan(1, 2, 3, 4, 5, 6), null }); + AddParameter(command, "p17", new [] { JsonDocument.Parse("{\"key\": \"value\"}"), null }); + AddParameter(command, "p18", new decimal?[] { 9.99m, null }); + AddParameter(command, "p19", new [] { "test", null }); + AddParameter(command, "p20", new DateTime?[] { new DateTime(2025, 10, 2, 15, 57, 31, 999, DateTimeKind.Utc), null }); + AddParameter(command, "p21", new Guid?[] { Guid.Parse("5555990c-b259-4539-bd22-5a9293cf10ac"), null }); + AddParameter(command, "p22", new double?[] { 3.14d, null }); + AddParameter(command, "p23", new float?[] { 3.14f, null }); + + await command.ExecuteNonQueryAsync(); + + var requests = Fixture.SpannerMock.Requests.ToList(); + Assert.That(requests.OfType().Count, Is.EqualTo(1)); + Assert.That(requests.OfType().Count, Is.EqualTo(1)); + var request = requests.OfType().Single(); + // The driver does not send any parameter types, unless it is explicitly asked to do so. + Assert.That(request.ParamTypes.Count, Is.EqualTo(0)); + Assert.That(request.Params.Fields.Count, Is.EqualTo(23)); + var fields = request.Params.Fields; + Assert.That(fields["p1"].HasBoolValue, Is.True); + Assert.That(fields["p1"].BoolValue, Is.True); + Assert.That(fields["p2"].HasStringValue, Is.True); + Assert.That(fields["p2"].StringValue, Is.EqualTo(Convert.ToBase64String(new byte[]{1,2,3}))); + Assert.That(fields["p3"].HasStringValue, Is.True); + Assert.That(fields["p3"].StringValue, Is.EqualTo("2025-10-02")); + Assert.That(fields["p4"].HasStringValue, Is.True); + Assert.That(fields["p4"].StringValue, Is.EqualTo("P1DT2H3M4.005006S")); + Assert.That(fields["p5"].HasStringValue, Is.True); + Assert.That(fields["p5"].StringValue, Is.EqualTo("{\"key\": \"value\"}")); + Assert.That(fields["p6"].HasStringValue, Is.True); + Assert.That(fields["p6"].StringValue, Is.EqualTo("9.99")); + Assert.That(fields["p7"].HasStringValue, Is.True); + Assert.That(fields["p7"].StringValue, Is.EqualTo("test")); + Assert.That(fields["p8"].HasStringValue, Is.True); + Assert.That(fields["p8"].StringValue, Is.EqualTo("2025-10-02T15:57:31.9990000Z")); + Assert.That(fields["p9"].HasStringValue, Is.True); + Assert.That(fields["p9"].StringValue, Is.EqualTo("5555990c-b259-4539-bd22-5a9293cf10ac")); + Assert.That(fields["p10"].HasNumberValue, Is.True); + Assert.That(fields["p10"].NumberValue, Is.EqualTo(3.14d)); + Assert.That(fields["p11"].HasNumberValue, Is.True); + Assert.That(fields["p11"].NumberValue, Is.EqualTo(3.14f)); + Assert.That(fields["p12"].HasNullValue, Is.True); + + Assert.That(fields["p13"].KindCase, Is.EqualTo(Value.KindOneofCase.ListValue)); + Assert.That(fields["p13"].ListValue.Values.Count, Is.EqualTo(3)); + Assert.That(fields["p13"].ListValue.Values[0].HasBoolValue, Is.True); + Assert.That(fields["p13"].ListValue.Values[0].BoolValue, Is.True); + Assert.That(fields["p13"].ListValue.Values[1].HasBoolValue, Is.True); + Assert.That(fields["p13"].ListValue.Values[1].BoolValue, Is.False); + Assert.That(fields["p13"].ListValue.Values[2].HasNullValue, Is.True); + + Assert.That(fields["p14"].KindCase, Is.EqualTo(Value.KindOneofCase.ListValue)); + Assert.That(fields["p14"].ListValue.Values.Count, Is.EqualTo(2)); + Assert.That(fields["p14"].ListValue.Values[0].HasStringValue, Is.True); + Assert.That(fields["p14"].ListValue.Values[0].StringValue, Is.EqualTo(Convert.ToBase64String(new byte[]{1,2,3}))); + Assert.That(fields["p14"].ListValue.Values[1].HasNullValue, Is.True); + + Assert.That(fields["p15"].KindCase, Is.EqualTo(Value.KindOneofCase.ListValue)); + Assert.That(fields["p15"].ListValue.Values.Count, Is.EqualTo(2)); + Assert.That(fields["p15"].ListValue.Values[0].HasStringValue, Is.True); + Assert.That(fields["p15"].ListValue.Values[0].StringValue, Is.EqualTo("2025-10-02")); + Assert.That(fields["p15"].ListValue.Values[1].HasNullValue, Is.True); + + Assert.That(fields["p16"].KindCase, Is.EqualTo(Value.KindOneofCase.ListValue)); + Assert.That(fields["p16"].ListValue.Values.Count, Is.EqualTo(2)); + Assert.That(fields["p16"].ListValue.Values[0].HasStringValue, Is.True); + Assert.That(fields["p16"].ListValue.Values[0].StringValue, Is.EqualTo("P1DT2H3M4.005006S")); + Assert.That(fields["p16"].ListValue.Values[1].HasNullValue, Is.True); + + Assert.That(fields["p17"].KindCase, Is.EqualTo(Value.KindOneofCase.ListValue)); + Assert.That(fields["p17"].ListValue.Values.Count, Is.EqualTo(2)); + Assert.That(fields["p17"].ListValue.Values[0].HasStringValue, Is.True); + Assert.That(fields["p17"].ListValue.Values[0].StringValue, Is.EqualTo("{\"key\": \"value\"}")); + Assert.That(fields["p17"].ListValue.Values[1].HasNullValue, Is.True); + + Assert.That(fields["p18"].KindCase, Is.EqualTo(Value.KindOneofCase.ListValue)); + Assert.That(fields["p18"].ListValue.Values.Count, Is.EqualTo(2)); + Assert.That(fields["p18"].ListValue.Values[0].HasStringValue, Is.True); + Assert.That(fields["p18"].ListValue.Values[0].StringValue, Is.EqualTo("9.99")); + Assert.That(fields["p18"].ListValue.Values[1].HasNullValue, Is.True); + + Assert.That(fields["p19"].KindCase, Is.EqualTo(Value.KindOneofCase.ListValue)); + Assert.That(fields["p19"].ListValue.Values.Count, Is.EqualTo(2)); + Assert.That(fields["p19"].ListValue.Values[0].HasStringValue, Is.True); + Assert.That(fields["p19"].ListValue.Values[0].StringValue, Is.EqualTo("test")); + Assert.That(fields["p19"].ListValue.Values[1].HasNullValue, Is.True); + + Assert.That(fields["p20"].KindCase, Is.EqualTo(Value.KindOneofCase.ListValue)); + Assert.That(fields["p20"].ListValue.Values.Count, Is.EqualTo(2)); + Assert.That(fields["p20"].ListValue.Values[0].HasStringValue, Is.True); + Assert.That(fields["p20"].ListValue.Values[0].StringValue, Is.EqualTo("2025-10-02T15:57:31.9990000Z")); + Assert.That(fields["p20"].ListValue.Values[1].HasNullValue, Is.True); + + Assert.That(fields["p21"].KindCase, Is.EqualTo(Value.KindOneofCase.ListValue)); + Assert.That(fields["p21"].ListValue.Values.Count, Is.EqualTo(2)); + Assert.That(fields["p21"].ListValue.Values[0].HasStringValue, Is.True); + Assert.That(fields["p21"].ListValue.Values[0].StringValue, Is.EqualTo("5555990c-b259-4539-bd22-5a9293cf10ac")); + Assert.That(fields["p21"].ListValue.Values[1].HasNullValue, Is.True); + + Assert.That(fields["p22"].KindCase, Is.EqualTo(Value.KindOneofCase.ListValue)); + Assert.That(fields["p22"].ListValue.Values.Count, Is.EqualTo(2)); + Assert.That(fields["p22"].ListValue.Values[0].HasNumberValue, Is.True); + Assert.That(fields["p22"].ListValue.Values[0].NumberValue, Is.EqualTo(3.14d)); + Assert.That(fields["p22"].ListValue.Values[1].HasNullValue, Is.True); + + Assert.That(fields["p23"].KindCase, Is.EqualTo(Value.KindOneofCase.ListValue)); + Assert.That(fields["p23"].ListValue.Values.Count, Is.EqualTo(2)); + Assert.That(fields["p23"].ListValue.Values[0].HasNumberValue, Is.True); + Assert.That(fields["p23"].ListValue.Values[0].NumberValue, Is.EqualTo(3.14f)); + Assert.That(fields["p23"].ListValue.Values[1].HasNullValue, Is.True); + } + + [Test] + public async Task TestExecuteNonQueryWithSelect() + { + const string sql = "select * from my_table"; + Fixture.SpannerMock.AddOrUpdateStatementResult(sql, StatementResult.CreateSelect1ResultSet()); + await using var conn = await OpenConnectionAsync(); + await using var cmd = new SpannerCommand(conn); + cmd.CommandText = sql; + + var result = await cmd.ExecuteNonQueryAsync(); + Assert.That(result, Is.EqualTo(-1)); + } + + [Test] + public async Task TestExecuteNonQueryWithError() + { + const string sql = "select * from my_table"; + Fixture.SpannerMock.AddOrUpdateStatementResult(sql, StatementResult.CreateException(new RpcException(new Status(StatusCode.NotFound, "Table not found")))); + await using var conn = await OpenConnectionAsync(); + await using var cmd = new SpannerCommand(conn); + cmd.CommandText = sql; + + Assert.ThrowsAsync(async () => await cmd.ExecuteNonQueryAsync()); + } + + [Test] + [TestCase(new[] { true }, TestName = "SingleQuery")] + [TestCase(new[] { false }, TestName = "SingleNonQuery")] + [TestCase(new[] { true, true }, TestName = "TwoQueries")] + [TestCase(new[] { false, false }, TestName = "TwoNonQueries")] + [TestCase(new[] { false, true }, TestName = "NonQueryQuery")] + [TestCase(new[] { true, false }, TestName = "QueryNonQuery")] + [Ignore("Requires support for multi-statements strings in the shared library")] + public async Task MultipleStatements(bool[] queries) + { + const string update = "UPDATE my_table SET name='yo' WHERE 1=0;"; + const string select = "SELECT 1;"; + Fixture.SpannerMock.AddOrUpdateStatementResult(update, StatementResult.CreateUpdateCount(0)); + + await using var conn = await OpenConnectionAsync(); + var sb = new StringBuilder(); + foreach (var query in queries) + { + sb.Append(query ? select : update); + } + var sql = sb.ToString(); + foreach (var prepare in new[] { false, true }) + { + await using var cmd = conn.CreateCommand(); + cmd.CommandText = sql; + if (prepare) + { + await cmd.PrepareAsync(); + } + await using var reader = await cmd.ExecuteReaderAsync(); + var numResultSets = queries.Count(q => q); + for (var i = 0; i < numResultSets; i++) + { + Assert.That(await reader.ReadAsync(), Is.True); + Assert.That(reader[0], Is.EqualTo(1)); + Assert.That(await reader.NextResultAsync(), Is.EqualTo(i != numResultSets - 1)); + } + } + } + + [Test] + [Ignore("Requires support for multi-statements strings in the shared library")] + public async Task MultipleStatementsWithParameters([Values(PrepareOrNot.NotPrepared, PrepareOrNot.Prepared)] PrepareOrNot prepare) + { + await using var conn = await OpenConnectionAsync(); + await using var cmd = conn.CreateCommand(); + cmd.CommandText = "SELECT @p1; SELECT @p2"; + var p1 = new SpannerParameter{ParameterName = "p1"}; + var p2 = new SpannerParameter{ParameterName = "p2"}; + cmd.Parameters.Add(p1); + cmd.Parameters.Add(p2); + if (prepare == PrepareOrNot.Prepared) + { + await cmd.PrepareAsync(); + } + p1.Value = 8; + p2.Value = "foo"; + await using var reader = await cmd.ExecuteReaderAsync(); + Assert.That(await reader.ReadAsync(), Is.True); + Assert.That(reader.GetInt32(0), Is.EqualTo(8)); + Assert.That(await reader.NextResultAsync(), Is.True); + Assert.That(await reader.ReadAsync(), Is.True); + Assert.That(reader.GetString(0), Is.EqualTo("foo")); + Assert.That(await reader.NextResultAsync(), Is.False); + } + + [Test] + [Ignore("Requires support for multi-statements strings in the shared library")] + public async Task SingleRowMultipleStatements([Values(PrepareOrNot.NotPrepared, PrepareOrNot.Prepared)] PrepareOrNot prepare) + { + await using var conn = await OpenConnectionAsync(); + await using var cmd = new SpannerCommand("SELECT 1; SELECT 2", conn); + if (prepare == PrepareOrNot.Prepared) + { + await cmd.PrepareAsync(); + } + await using var reader = await cmd.ExecuteReaderAsync(CommandBehavior.SingleRow); + Assert.That(await reader.ReadAsync(), Is.True); + Assert.That(reader.GetInt32(0), Is.EqualTo(1)); + Assert.That(await reader.ReadAsync(), Is.False); + Assert.That(await reader.NextResultAsync(), Is.False); + } + + [Test] + [Ignore("Requires support for statement_timeout in the shared library")] + public async Task Timeout() + { + Fixture.SpannerMock.AddOrUpdateExecutionTime(nameof(Fixture.SpannerMock.ExecuteStreamingSql), ExecutionTime.FromMillis(10, 0)); + + await using var dataSource = CreateDataSource(csb => csb.CommandTimeout = 1); + await using var conn = await dataSource.OpenConnectionAsync() as SpannerConnection; + await using var cmd = new SpannerCommand("SELECT 1", conn!); + Assert.That(() => cmd.ExecuteScalar(), Throws.Exception + .TypeOf() + .With.InnerException.TypeOf() + ); + Assert.That(conn!.State, Is.EqualTo(ConnectionState.Open)); + } + + [Test] + [Ignore("Requires support for statement_timeout in the shared library")] + public async Task TimeoutAsync() + { + Fixture.SpannerMock.AddOrUpdateExecutionTime(nameof(Fixture.SpannerMock.ExecuteStreamingSql), ExecutionTime.FromMillis(10, 0)); + + await using var dataSource = CreateDataSource(csb => csb.CommandTimeout = 1); + await using var conn = await dataSource.OpenConnectionAsync() as SpannerConnection; + await using var cmd = new SpannerCommand("SELECT 1", conn!); + Assert.That(async () => await cmd.ExecuteScalarAsync(), + Throws.Exception + .TypeOf() + .With.InnerException.TypeOf()); + Assert.That(conn!.State, Is.EqualTo(ConnectionState.Open)); + } + + [Test] + public async Task TimeoutSwitchConnection() + { + var csb = new SpannerConnectionStringBuilder(ConnectionString); + Assert.That(csb.CommandTimeout, Is.EqualTo(0)); + + await using var dataSource1 = CreateDataSource(ConnectionString + ";CommandTimeout=100"); + await using var c1 = dataSource1.CreateConnection(); + await using var cmd = c1.CreateCommand(); + Assert.That(cmd.CommandTimeout, Is.EqualTo(100)); + await using var dataSource2 = CreateDataSource(ConnectionString + ";CommandTimeout=101"); + await using (var c2 = dataSource2.CreateConnection()) + { + cmd.Connection = c2; + Assert.That(cmd.CommandTimeout, Is.EqualTo(101)); + } + cmd.CommandTimeout = 102; + await using (var c2 = dataSource2.CreateConnection()) + { + cmd.Connection = c2; + Assert.That(cmd.CommandTimeout, Is.EqualTo(102)); + } + } + + [Test] + [Ignore("Requires support for cancel in the shared library")] + public async Task Cancel() + { + var sql = "insert into my_table (id, value) values (1, 'one')"; + Fixture.SpannerMock.AddOrUpdateStatementResult(sql, StatementResult.CreateUpdateCount(1L)); + Fixture.SpannerMock.AddOrUpdateExecutionTime(nameof(Fixture.SpannerMock.ExecuteStreamingSql), ExecutionTime.FromMillis(50, 0)); + + await using var conn = await OpenConnectionAsync(); + await using var cmd = conn.CreateCommand(); + cmd.CommandText = sql; + + // ReSharper disable once AccessToDisposedClosure + var queryTask = Task.Run(() => cmd.ExecuteNonQuery()); + // Wait until the request is on the mock server. + Fixture.SpannerMock.WaitForRequestsToContain(message => message is ExecuteSqlRequest request && request.Sql == sql); + cmd.Cancel(); + Assert.That(async () => await queryTask, Throws + .TypeOf() + .With.InnerException.TypeOf() + .With.InnerException.Property(nameof(SpannerDbException.ErrorCode)).EqualTo(StatusCode.Cancelled) + ); + } + + [Test] + [SuppressMessage("ReSharper", "MethodHasAsyncOverload")] + public async Task CloseConnection() + { + await using var conn = await OpenConnectionAsync(); + await using (var cmd = new SpannerCommand("SELECT 1", conn)) + await using (var reader = await cmd.ExecuteReaderAsync(CommandBehavior.CloseConnection)) + { + while (reader.Read()) + { + } + } + Assert.That(conn.State, Is.EqualTo(ConnectionState.Closed)); + } + + [Test] + [SuppressMessage("ReSharper", "MethodHasAsyncOverload")] + public async Task CloseDuringRead() + { + await using var dataSource = CreateDataSource(); + await using var conn = (await dataSource.OpenConnectionAsync() as SpannerConnection)!; + await using (var cmd = new SpannerCommand("SELECT 1", conn)) + await using (var reader = await cmd.ExecuteReaderAsync()) + { + reader.Read(); + conn.Close(); + Assert.That(conn.State, Is.EqualTo(ConnectionState.Closed)); + // Closing a SpannerConnection does not close the related readers. + Assert.False(reader.IsClosed); + } + + conn.Open(); + Assert.That(conn.State, Is.EqualTo(ConnectionState.Open)); + Assert.That(await conn.ExecuteScalarAsync("SELECT 1"), Is.EqualTo(1)); + } + + [Test] + public async Task CloseConnectionWithException() + { + const string sql = "select * from non_existing_table"; + Fixture.SpannerMock.AddOrUpdateStatementResult(sql, StatementResult.CreateException(new RpcException(new Status(StatusCode.NotFound, "Table not found")))); + + await using var conn = await OpenConnectionAsync(); + await using (var cmd = new SpannerCommand(sql, conn)) + { + Assert.That(() => cmd.ExecuteReaderAsync(CommandBehavior.CloseConnection), + Throws.Exception.TypeOf()); + } + Assert.That(conn.State, Is.EqualTo(ConnectionState.Closed)); + } + + [Test] + [SuppressMessage("ReSharper", "MethodHasAsyncOverload")] + public async Task SingleRow([Values(PrepareOrNot.NotPrepared, PrepareOrNot.Prepared)] PrepareOrNot prepare) + { + const string sql = "SELECT 1, 2 UNION SELECT 3, 4"; + Fixture.SpannerMock.AddOrUpdateStatementResult(sql, StatementResult.CreateResultSet( + new List>([Tuple.Create(TypeCode.Int64, "c1"), Tuple.Create(TypeCode.Int64, "c2")]), + new List([[1L, 2L], [3L, 4L]]))); + + await using var conn = await OpenConnectionAsync(); + await using var cmd = new SpannerCommand(sql, conn); + if (prepare == PrepareOrNot.Prepared) + { + cmd.Prepare(); + } + + await using var reader = await cmd.ExecuteReaderAsync(CommandBehavior.SingleRow); + Assert.That(() => reader.GetInt32(0), Throws.Exception.TypeOf()); + Assert.That(reader.Read(), Is.True); + Assert.That(reader.GetInt32(0), Is.EqualTo(1)); + Assert.That(reader.Read(), Is.False); + } + + [Test] + public async Task CommandTextNotSet() + { + await using var conn = await OpenConnectionAsync(); + await using (var cmd = new SpannerCommand()) + { + cmd.Connection = conn; + Assert.That(cmd.ExecuteNonQueryAsync, Throws.Exception.TypeOf()); + cmd.CommandText = null; + Assert.That(cmd.ExecuteNonQueryAsync, Throws.Exception.TypeOf()); + cmd.CommandText = ""; + } + + await using (var cmd = conn.CreateCommand()) + { + Assert.That(cmd.ExecuteNonQueryAsync, Throws.Exception.TypeOf()); + } + } + + [Test] + public async Task ExecuteScalar() + { + const string sql = "select name from my_table"; + Fixture.SpannerMock.AddOrUpdateStatementResult(sql, StatementResult.CreateResultSet( + new List>([Tuple.Create(TypeCode.String, "name")]), + new List([]))); + + await using var conn = await OpenConnectionAsync(); + await using var command = new SpannerCommand(sql, conn); + Assert.That(command.ExecuteScalarAsync, Is.Null); + + Fixture.SpannerMock.AddOrUpdateStatementResult(sql, StatementResult.CreateResultSet( + new List>([Tuple.Create(TypeCode.String, "name")]), + new List([[DBNull.Value]]))); + Assert.That(command.ExecuteScalarAsync, Is.EqualTo(DBNull.Value)); + + Fixture.SpannerMock.AddOrUpdateStatementResult(sql, StatementResult.CreateResultSet( + new List>([Tuple.Create(TypeCode.String, "name")]), + new List([["X1"], ["X2"]]))); + Assert.That(command.ExecuteScalarAsync, Is.EqualTo("X1")); + } + + [Test] + public async Task ExecuteNonQuery() + { + const string insertOneRow = "insert into my_table (name) values ('Test')"; + Fixture.SpannerMock.AddOrUpdateStatementResult(insertOneRow, StatementResult.CreateUpdateCount(1L)); + + await using var conn = await OpenConnectionAsync(); + await using var cmd = conn.CreateCommand(); + + // Insert one row + cmd.CommandText = insertOneRow; + Assert.That(cmd.ExecuteNonQueryAsync, Is.EqualTo(1)); + + // Insert two rows in one batch using a SQL string that contains two statements. + // TODO: Enable when SpannerLib supports SQL strings with multiple statements. + // cmd.CommandText = $"{insertOneRow}; {insertOneRow}"; + // Assert.That(cmd.ExecuteNonQueryAsync, Is.EqualTo(2)); + + // Execute a large SQL string. + var value = TestUtils.GenerateRandomString(10_000_000); + cmd.CommandText = $"insert into my_table (name) values ('{value}')"; + Fixture.SpannerMock.AddOrUpdateStatementResult(cmd.CommandText, StatementResult.CreateUpdateCount(1L)); + Assert.That(cmd.ExecuteNonQueryAsync, Is.EqualTo(1)); + } + + [Test] + public async Task Dispose() + { + await using var conn = await OpenConnectionAsync(); + var cmd = new SpannerCommand("SELECT 1", conn); + cmd.Dispose(); + Assert.That(() => cmd.ExecuteScalarAsync(), Throws.Exception.TypeOf()); + Assert.That(() => cmd.ExecuteNonQueryAsync(), Throws.Exception.TypeOf()); + Assert.That(() => cmd.ExecuteReaderAsync(), Throws.Exception.TypeOf()); + Assert.That(() => cmd.PrepareAsync(), Throws.Exception.TypeOf()); + } + + [Test] + public async Task DisposeDesNotCloseReader() + { + await using var conn = await OpenConnectionAsync(); + var cmd = new SpannerCommand("SELECT 1", conn); + await using var reader1 = await cmd.ExecuteReaderAsync(); + cmd.Dispose(); + cmd = new SpannerCommand("SELECT 1", conn); + await using var reader2 = await cmd.ExecuteReaderAsync(); + Assert.That(reader2, Is.Not.Null); + Assert.That(reader1.IsClosed, Is.False); + Assert.That(await reader1.ReadAsync(), Is.True); + } + + [Test] + [TestCase(CommandBehavior.Default)] + [TestCase(CommandBehavior.SequentialAccess)] + public async Task StatementMappedOutputParameters(CommandBehavior behavior) + { + const string sql = "select 3, 4 as param1, 5 as param2, 6"; + Fixture.SpannerMock.AddOrUpdateStatementResult(sql, StatementResult.CreateResultSet( + new List>([ + Tuple.Create(TypeCode.Int64, "c1"), + Tuple.Create(TypeCode.Int64, "param1"), + Tuple.Create(TypeCode.Int64, "param2"), + Tuple.Create(TypeCode.Int64, "c2")]), + new List([[3, 4, 5, 6]]))); + + await using var conn = await OpenConnectionAsync(); + var command = new SpannerCommand(sql, conn); + + var p = new SpannerParameter + { + ParameterName = "param2", + Direction = ParameterDirection.Output, + Value = -1, + DbType = DbType.Int64, + }; + command.Parameters.Add(p); + + p = new SpannerParameter + { + ParameterName = "param1", + Direction = ParameterDirection.Output, + Value = -1, + DbType = DbType.Int64, + }; + command.Parameters.Add(p); + + p = new SpannerParameter + { + ParameterName = "p", + Direction = ParameterDirection.Output, + Value = -1, + DbType = DbType.Int64, + }; + command.Parameters.Add(p); + + await using var reader = await command.ExecuteReaderAsync(behavior); + + // TODO: Enable if we decide to support output parameters in the same way as npgsql. + // Assert.That(command.Parameters["param1"].Value, Is.EqualTo(4)); + // Assert.That(command.Parameters["param2"].Value, Is.EqualTo(5)); + + await reader.ReadAsync(); + + Assert.That(reader.GetInt32(0), Is.EqualTo(3)); + Assert.That(reader.GetInt32(1), Is.EqualTo(4)); + Assert.That(reader.GetInt32(2), Is.EqualTo(5)); + Assert.That(reader.GetInt32(3), Is.EqualTo(6)); + } + + [Test] + public async Task TableDirect() + { + const string sql = "select * from my_table"; + Fixture.SpannerMock.AddOrUpdateStatementResult(sql, StatementResult.CreateResultSet( + new List>([Tuple.Create(TypeCode.String, "name")]), + new List([["foo"]]))); + await using var conn = await OpenConnectionAsync(); + + await using var cmd = new SpannerCommand("my_table", conn) { CommandType = CommandType.TableDirect }; + await using var reader = await cmd.ExecuteReaderAsync(); + Assert.That(await reader.ReadAsync(), Is.True); + Assert.That(reader["name"], Is.EqualTo("foo")); + } + + [Test] + public async Task InvalidUtf8() + { + const string sql = "SELECT 'abc\uD801\uD802d'"; + Fixture.SpannerMock.AddOrUpdateStatementResult("SELECT 'abc��d'", StatementResult.CreateResultSet( + new List>([Tuple.Create(TypeCode.String, "c")]), + new List([["abc��d"]]))); + + await using var dataSource = CreateDataSource(); + await using var conn = await dataSource.OpenConnectionAsync() as SpannerConnection; + var value = await conn!.ExecuteScalarAsync(sql); + Assert.That(value, Is.EqualTo("abc��d")); + } + + [Test] + public async Task UseAcrossConnectionChange([Values(PrepareOrNot.Prepared, PrepareOrNot.NotPrepared)] PrepareOrNot prepare) + { + await using var conn1 = await OpenConnectionAsync(); + await using var conn2 = await OpenConnectionAsync(); + await using var cmd = new SpannerCommand("SELECT 1", conn1); + if (prepare == PrepareOrNot.Prepared) + { + await cmd.PrepareAsync(); + } + cmd.Connection = conn2; + if (prepare == PrepareOrNot.Prepared) + { + await cmd.PrepareAsync(); + } + Assert.That(await cmd.ExecuteScalarAsync(), Is.EqualTo(1)); + } + + [Test] + public async Task CreateCommandBeforeConnectionOpen() + { + await using var conn = new SpannerConnection(ConnectionString); + var cmd = new SpannerCommand("SELECT 1", conn); + conn.Open(); + Assert.That(await cmd.ExecuteScalarAsync(), Is.EqualTo(1)); + } + + [Test] + public void ConnectionNotSetThrows() + { + var cmd = new SpannerCommand { CommandText = "SELECT 1" }; + Assert.That(() => cmd.ExecuteScalarAsync(), Throws.Exception.TypeOf()); + } + + [Test] + public void ConnectionNotOpenTrows() + { + using var conn = new SpannerConnection(ConnectionString); + var cmd = new SpannerCommand("SELECT 1", conn); + Assert.That(() => cmd.ExecuteScalarAsync(), Throws.Exception.TypeOf()); + } + + [Test] + public async Task ExecuteNonQueryThrowsSpannerDbException([Values] bool async) + { + const string sql = "insert into my_table (ref) values (1) returning ref"; + Fixture.SpannerMock.AddOrUpdateStatementResult(sql, StatementResult.CreateException(new RpcException(new Status(StatusCode.FailedPrecondition, "Foreign key constraint violation")))); + await using var conn = await OpenConnectionAsync(); + + var ex = async + ? Assert.ThrowsAsync(async () => await conn.ExecuteNonQueryAsync(sql)) + : Assert.Throws(() => conn.ExecuteNonQuery(sql)); + Assert.That(ex!.Status.Code, Is.EqualTo((int) StatusCode.FailedPrecondition)); + } + + [Test] + public async Task ExecuteScalarThrowsSpannerDbException([Values] bool async) + { + const string sql = "insert into my_table (ref) values (1) returning ref"; + Fixture.SpannerMock.AddOrUpdateStatementResult(sql, StatementResult.CreateException(new RpcException(new Status(StatusCode.FailedPrecondition, "Foreign key constraint violation")))); + await using var conn = await OpenConnectionAsync(); + + var ex = async + ? Assert.ThrowsAsync(async () => await conn.ExecuteScalarAsync(sql)) + : Assert.Throws(() => conn.ExecuteScalar(sql)); + Assert.That(ex!.Status.Code, Is.EqualTo((int) StatusCode.FailedPrecondition)); + } + + [Test] + [SuppressMessage("ReSharper", "MethodHasAsyncOverload")] + public async Task ExecuteReaderThrowsSpannerDbException([Values] bool async) + { + const string sql = "insert into my_table (ref) values (1) returning ref"; + Fixture.SpannerMock.AddOrUpdateStatementResult(sql, StatementResult.CreateException(new RpcException(new Status(StatusCode.FailedPrecondition, "Foreign key constraint violation")))); + await using var conn = await OpenConnectionAsync(); + await using var cmd = conn.CreateCommand(); + cmd.CommandText = sql; + + var ex = async + ? Assert.ThrowsAsync(async () => await cmd.ExecuteReaderAsync()) + : Assert.Throws(() => cmd.ExecuteReader()); + Assert.That(ex!.Status.Code, Is.EqualTo((int) StatusCode.FailedPrecondition)); + } + + [Test] + public void CommandIsNotRecycled() + { + const string sql = "select @p1"; + Fixture.SpannerMock.AddOrUpdateStatementResult(sql, StatementResult.CreateSingleColumnResultSet(new V1.Type{Code = TypeCode.Int64}, "p1", 8L)); + + using var conn = new SpannerConnection(ConnectionString); + var cmd1 = conn.CreateCommand(); + cmd1.CommandText = sql; + var tx = conn.BeginTransaction(); + cmd1.Transaction = tx; + AddParameter(cmd1, "p1", 8); + _ = cmd1.ExecuteScalar(); + cmd1.Dispose(); + + var cmd2 = conn.CreateCommand(); + Assert.That(cmd2, Is.Not.SameAs(cmd1)); + Assert.That(cmd2.CommandText, Is.Empty); + Assert.That(cmd2.CommandType, Is.EqualTo(CommandType.Text)); + Assert.That(cmd2.Transaction, Is.Null); + Assert.That(cmd2.Parameters, Is.Empty); + } + + [Test] + public async Task ManyParameters([Values(PrepareOrNot.NotPrepared, PrepareOrNot.Prepared)] PrepareOrNot prepare) + { + await using var conn = await OpenConnectionAsync(); + await using var cmd = new SpannerCommand(conn); + var sb = new StringBuilder($"INSERT INTO my_table (some_column) VALUES "); + var numParams = ushort.MaxValue; + for (var i = 0; i < numParams; i++) + { + var paramName = "p" + i; + AddParameter(cmd, paramName, i); + if (i > 0) + sb.Append(", "); + sb.Append($"(@{paramName})"); + } + cmd.CommandText = sb.ToString(); + Fixture.SpannerMock.AddOrUpdateStatementResult(cmd.CommandText, StatementResult.CreateUpdateCount(numParams)); + + if (prepare == PrepareOrNot.Prepared) + { + await cmd.PrepareAsync(); + } + await cmd.ExecuteNonQueryAsync(); + } + + [Test] + [Ignore("Requires multi-statement support in SpannerLib")] + public async Task ManyParametersAcrossStatements() + { + var result = StatementResult.CreateSelect1ResultSet(); + // Create a command with 1000 statements which have 70 params each + await using var conn = await OpenConnectionAsync(); + await using var cmd = new SpannerCommand(conn); + var paramIndex = 0; + var sb = new StringBuilder(); + for (var statementIndex = 0; statementIndex < 1000; statementIndex++) + { + if (statementIndex > 0) + sb.Append("; "); + var statement = new StringBuilder(); + statement.Append("SELECT "); + var startIndex = paramIndex; + var endIndex = paramIndex + 70; + for (; paramIndex < endIndex; paramIndex++) + { + var paramName = "p" + paramIndex; + AddParameter(cmd, paramName, paramIndex); + if (paramIndex > startIndex) + statement.Append(", "); + statement.Append('@'); + statement.Append(paramName); + } + sb.Append(statement); + Fixture.SpannerMock.AddOrUpdateStatementResult(statement.ToString(), result); + } + cmd.CommandText = sb.ToString(); + await cmd.ExecuteNonQueryAsync(); + } + + [Test] + public async Task SameCommandDifferentParamValues() + { + const string sql = "select @p"; + Fixture.SpannerMock.AddOrUpdateStatementResult(sql, StatementResult.CreateSingleColumnResultSet(new V1.Type{Code = TypeCode.Int64}, "p", 8L)); + + await using var conn = await OpenConnectionAsync(); + await using var cmd = new SpannerCommand(sql, conn); + AddParameter(cmd, "p", 8); + await cmd.ExecuteNonQueryAsync(); + var request = Fixture.SpannerMock.Requests.First(r => r is ExecuteSqlRequest { Sql: sql }) as ExecuteSqlRequest; + Assert.That(request, Is.Not.Null); + Assert.That(request.Params.Fields["p"].StringValue, Is.EqualTo("8")); + + Fixture.SpannerMock.ClearRequests(); + Fixture.SpannerMock.AddOrUpdateStatementResult(sql, StatementResult.CreateSingleColumnResultSet(new V1.Type{Code = TypeCode.Int64}, "p", 9L)); + + cmd.Parameters[0].Value = 9; + Assert.That(await cmd.ExecuteScalarAsync(), Is.EqualTo(9)); + request = Fixture.SpannerMock.Requests.First(r => r is ExecuteSqlRequest { Sql: sql }) as ExecuteSqlRequest; + Assert.That(request, Is.Not.Null); + Assert.That(request.Params.Fields["p"].StringValue, Is.EqualTo("9")); + } + + [Test] + public async Task SameCommandDifferentParamInstances() + { + const string sql = "select @p"; + Fixture.SpannerMock.AddOrUpdateStatementResult(sql, StatementResult.CreateSingleColumnResultSet(new V1.Type{Code = TypeCode.Int64}, "p", 8L)); + + await using var conn = await OpenConnectionAsync(); + await using var cmd = new SpannerCommand(sql, conn); + AddParameter(cmd, "p", 8); + await cmd.ExecuteNonQueryAsync(); + var request = Fixture.SpannerMock.Requests.First(r => r is ExecuteSqlRequest { Sql: sql }) as ExecuteSqlRequest; + Assert.That(request, Is.Not.Null); + Assert.That(request.Params.Fields["p"].StringValue, Is.EqualTo("8")); + + Fixture.SpannerMock.ClearRequests(); + Fixture.SpannerMock.AddOrUpdateStatementResult(sql, StatementResult.CreateSingleColumnResultSet(new V1.Type{Code = TypeCode.Int64}, "p", 9L)); + + cmd.Parameters.RemoveAt(0); + AddParameter(cmd, "p", 9); + Assert.That(await cmd.ExecuteScalarAsync(), Is.EqualTo(9)); + request = Fixture.SpannerMock.Requests.First(r => r is ExecuteSqlRequest { Sql: sql }) as ExecuteSqlRequest; + Assert.That(request, Is.Not.Null); + Assert.That(request.Params.Fields["p"].StringValue, Is.EqualTo("9")); + } + + [Test] + public async Task CancelWhileReadingFromLongRunningQuery() + { + const string sql = "select id from my_table"; + Fixture.SpannerMock.AddOrUpdateStatementResult(sql, StatementResult.CreateSingleColumnResultSet( + new V1.Type{Code = TypeCode.Int64}, + "id", + [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20])); + await using var conn = await OpenConnectionAsync(); + + await using var cmd = conn.CreateCommand(); + cmd.CommandText = sql; + + using (var cts = new CancellationTokenSource()) + await using (var reader = await cmd.ExecuteReaderAsync(cts.Token)) + { + Assert.ThrowsAsync(async () => + { + var i = 0; + while (await reader.ReadAsync(cts.Token)) + { + i++; + if (i == 10) + { + await cts.CancelAsync(); + } + } + }); + } + + cmd.CommandText = "SELECT 1"; + Assert.That(await cmd.ExecuteScalarAsync(CancellationToken.None), Is.EqualTo(1)); + } + + [Test] + public async Task CompletedTransactionThrows([Values] bool commit) + { + await using var conn = await OpenConnectionAsync(); + await using var tx = await conn.BeginTransactionAsync(); + await using var cmd = conn.CreateCommand(); + + if (commit) + { + await tx.CommitAsync(); + } + else + { + await tx.RollbackAsync(); + } + Assert.Throws(() => cmd.Transaction = tx); + } + + private void AddParameter(DbCommand command, string name, object? value) + { + var parameter = command.CreateParameter(); + parameter.ParameterName = name; + parameter.Value = value; + command.Parameters.Add(parameter); + } +} \ No newline at end of file diff --git a/drivers/spanner-ado-net/spanner-ado-net-tests/ConnectionStringBuilderTests.cs b/drivers/spanner-ado-net/spanner-ado-net-tests/ConnectionStringBuilderTests.cs new file mode 100644 index 00000000..d413edd8 --- /dev/null +++ b/drivers/spanner-ado-net/spanner-ado-net-tests/ConnectionStringBuilderTests.cs @@ -0,0 +1,171 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +using System.Data; + +namespace Google.Cloud.Spanner.DataProvider.Tests; + +public class ConnectionStringBuilderTests +{ + [Test] + public void Basic() + { + var builder = new SpannerConnectionStringBuilder(); + Assert.That(builder.Keys, Is.Empty); + Assert.That(builder.Count, Is.EqualTo(0)); + Assert.False(builder.ContainsKey("host")); + builder.Host = "myhost"; + Assert.That(builder["host"], Is.EqualTo("myhost")); + Assert.That(builder.Count, Is.EqualTo(1)); + Assert.That(builder.ConnectionString, Is.EqualTo("Host=myhost")); + builder.Remove("HOST"); + Assert.That(builder["host"], Is.EqualTo("")); + Assert.That(builder.Count, Is.EqualTo(0)); + } + + [Test] + public void TryGetValue() + { + var builder = new SpannerConnectionStringBuilder + { + ConnectionString = "Host=myhost" + }; + Assert.That(builder.TryGetValue("Host", out var value), Is.True); + Assert.That(value, Is.EqualTo("myhost")); + Assert.That(builder.TryGetValue("SomethingUnknown", out value), Is.False); + } + + [Test] + public void Remove() + { + var builder = new SpannerConnectionStringBuilder + { + UsePlainText = true + }; + Assert.That(builder["Use plain text"], Is.True); + builder.Remove("UsePlainText"); + Assert.That(builder.ConnectionString, Is.EqualTo("")); + } + + [Test] + public void Clear() + { + var builder = new SpannerConnectionStringBuilder { Host = "myhost" }; + builder.Clear(); + Assert.That(builder.Count, Is.EqualTo(0)); + Assert.That(builder["host"], Is.EqualTo("")); + Assert.That(builder.Host, Is.Empty); + } + + [Test] + public void RemovingResetsToDefault() + { + var builder = new SpannerConnectionStringBuilder(); + Assert.That(builder.Port, Is.EqualTo(SpannerConnectionStringOption.Port.DefaultValue)); + builder.Port = 8; + builder.Remove("Port"); + Assert.That(builder.Port, Is.EqualTo(SpannerConnectionStringOption.Port.DefaultValue)); + } + + [Test] + public void SettingToNullResetsToDefault() + { + var builder = new SpannerConnectionStringBuilder(); + Assert.That(builder.Port, Is.EqualTo(SpannerConnectionStringOption.Port.DefaultValue)); + builder.Port = 8; + builder["Port"] = null; + Assert.That(builder.Port, Is.EqualTo(SpannerConnectionStringOption.Port.DefaultValue)); + } + + [Test] + public void Enum() + { + var builder = new SpannerConnectionStringBuilder + { + ConnectionString = "DefaultIsolationLevel=Serializable" + }; + Assert.That(builder.DefaultIsolationLevel, Is.EqualTo(IsolationLevel.Serializable)); + Assert.That(builder.Count, Is.EqualTo(1)); + } + + [Test] + public void EnumCaseInsensitive() + { + var builder = new SpannerConnectionStringBuilder + { + ConnectionString = "defaultisolationlevel=repeatable read" + }; + Assert.That(builder.DefaultIsolationLevel, Is.EqualTo(IsolationLevel.RepeatableRead)); + Assert.That(builder.Count, Is.EqualTo(1)); + } + + [Test] + public void Clone() + { + var builder = new SpannerConnectionStringBuilder + { + Host = "myhost" + }; + var builder2 = builder.Clone(); + Assert.That(builder2.Host, Is.EqualTo("myhost")); + Assert.That(builder2["Host"], Is.EqualTo("myhost")); + Assert.That(builder.Port, Is.EqualTo(SpannerConnectionStringOption.Port.DefaultValue)); + } + + [Test] + public void ConversionErrorThrows() + { + // ReSharper disable once CollectionNeverQueried.Local + var builder = new SpannerConnectionStringBuilder(); + Assert.That(() => builder["Port"] = "hello", + Throws.Exception.TypeOf().With.Message.Contains("Port")); + } + + [Test] + public void InvalidConnectionStringThrows() + { + var builder = new SpannerConnectionStringBuilder(); + Assert.That(() => builder.ConnectionString = "Server=127.0.0.1;User Id=npgsql_tests;Pooling:false", + Throws.Exception.TypeOf()); + } + + [Test] + public void ConnectionStringToProperties() + { + var builder = new SpannerConnectionStringBuilder + { + ConnectionString = "Host=localhost;Port=80;UsePlainText=true;DefaultIsolationLevel=Repeatable read", + }; + Assert.That(builder.Host, Is.EqualTo("localhost")); + Assert.That(builder.Port, Is.EqualTo(80)); + Assert.That(builder.UsePlainText, Is.True); + Assert.That(builder.DefaultIsolationLevel, Is.EqualTo(IsolationLevel.RepeatableRead)); + } + + [Test] + public void PropertiesToConnectionString() + { + var builder = new SpannerConnectionStringBuilder + { + Host = "localhost", + Port = 80, + UsePlainText = true, + DefaultIsolationLevel = IsolationLevel.RepeatableRead, + DataSource = "projects/project1/instances/instance1/databases/database1" + }; + Assert.That(builder.ConnectionString, Is.EqualTo("Data Source=projects/project1/instances/instance1/databases/database1;Host=localhost;Port=80;UsePlainText=True;DefaultIsolationLevel=RepeatableRead")); + Assert.That(builder.SpannerLibConnectionString, Is.EqualTo("localhost:80/projects/project1/instances/instance1/databases/database1;UsePlainText=True;DefaultIsolationLevel=RepeatableRead")); + } + +} \ No newline at end of file diff --git a/drivers/spanner-ado-net/spanner-ado-net-tests/ConnectionTests.cs b/drivers/spanner-ado-net/spanner-ado-net-tests/ConnectionTests.cs new file mode 100644 index 00000000..dddbc070 --- /dev/null +++ b/drivers/spanner-ado-net/spanner-ado-net-tests/ConnectionTests.cs @@ -0,0 +1,559 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +using System.Data; +using System.Diagnostics.CodeAnalysis; +using Google.Cloud.Spanner.V1; +using Google.Cloud.SpannerLib; +using Google.Cloud.SpannerLib.MockServer; +using Google.Rpc; +using Grpc.Core; +using Status = Grpc.Core.Status; +using TypeCode = Google.Cloud.Spanner.V1.TypeCode; + +namespace Google.Cloud.Spanner.DataProvider.Tests; + +public class ConnectionTests : AbstractMockServerTests +{ + [Test] + public void TestOpenConnection() + { + var connection = new SpannerConnection { ConnectionString = ConnectionString }; + connection.Open(); + connection.Close(); + } + + [Test] + public void TestExecute() + { + var sql = "update all_types set col_float8=1 where col_bigint=1"; + Fixture.SpannerMock.AddOrUpdateStatementResult(sql, StatementResult.CreateUpdateCount(1)); + using var connection = new SpannerConnection(); + connection.ConnectionString = ConnectionString; + connection.Open(); + + var command = connection.CreateCommand(); + command.CommandText = sql; + var updateCount = command.ExecuteNonQuery(); + Assert.That(updateCount, Is.EqualTo(1)); + } + + [Test] + public void TestQuery() + { + var sql = "select col_varchar from all_types where col_varchar is not null limit 10"; + Fixture.SpannerMock.AddOrUpdateStatementResult(sql, StatementResult.CreateSingleColumnResultSet(new V1.Type{Code = TypeCode.String}, "col_varchar", "value1", "value2", "value3")); + using var connection = new SpannerConnection(); + connection.ConnectionString = ConnectionString; + connection.Open(); + + using var command = connection.CreateCommand(); + command.CommandText = sql; + var rowCount = 0; + using (var reader = command.ExecuteReader()) + { + while (reader.Read()) + { + rowCount++; + Assert.That(reader.GetString(0), Is.EqualTo($"value{rowCount}")); + } + } + Assert.That(rowCount, Is.EqualTo(3)); + } + + [Test] + public void TestParameterizedQuery() + { + var sql = "select * from all_types where col_varchar=@p1"; + Fixture.SpannerMock.AddOrUpdateStatementResult(sql, StatementResult.CreateSelect1ResultSet()); + using var connection = new SpannerConnection(); + connection.ConnectionString = ConnectionString; + connection.Open(); + + using var command = connection.CreateCommand(); + command.CommandText = sql; + command.Parameters.Add("2de7b24590e00a58fa7358c9531301c5"); + using (var reader = command.ExecuteReader()) + { + for (int i = 0; i < reader.FieldCount; i++) + { + Assert.That(reader.GetFieldType(i), Is.Not.Null); + } + while (reader.Read()) + { + for (int i = 0; i < reader.FieldCount; i++) + { + Assert.That(reader.GetValue(i), Is.Not.Null); + } + } + } + var requests = Fixture.SpannerMock.Requests.OfType().ToList(); + Assert.That(requests, Has.Count.EqualTo(1)); + var request = requests.First(); + Assert.That(request.Params.Fields, Has.Count.EqualTo(1)); + } + + [Test] + public void TestTransaction() + { + var sql = "select * from all_types where col_varchar=$1"; + Fixture.SpannerMock.AddOrUpdateStatementResult(sql, StatementResult.CreateSelect1ResultSet()); + using var connection = new SpannerConnection(); + connection.ConnectionString = ConnectionString; + connection.Open(); + + using var transaction = connection.BeginTransaction(); + using var command = connection.CreateCommand(); + command.Transaction = transaction; + command.CommandText = sql; + command.Parameters.Add("2de7b24590e00a58fa7358c9531301c5"); + using (var reader = command.ExecuteReader()) + { + for (int i = 0; i < reader.FieldCount; i++) + { + Assert.That(reader.GetFieldType(i), Is.Not.Null); + } + while (reader.Read()) + { + for (int i = 0; i < reader.FieldCount; i++) + { + Assert.That(reader.GetValue(i), Is.Not.Null); + } + } + } + transaction.Commit(); + } + + [Test] + public void TestDisableInternalRetries() + { + var sql = "update my_table set value=@p1 where id=@p2 and version=1"; + Fixture.SpannerMock.AddOrUpdateStatementResult(sql, StatementResult.CreateUpdateCount(1)); + using var connection = new SpannerConnection(); + connection.ConnectionString = ConnectionString + ";retryAbortsInternally=false"; + connection.Open(); + + using var transaction = connection.BeginTransaction(); + using var command = connection.CreateCommand(); + command.Transaction = transaction; + command.CommandText = sql; + command.Parameters.Add("2de7b24590e00a58fa7358c9531301c5"); + command.Parameters.Add(1L); + command.ExecuteNonQuery(); + Fixture.SpannerMock.AddOrUpdateExecutionTime(nameof(Fixture.SpannerMock.Commit), ExecutionTime.CreateException(StatusCode.Aborted, "Transaction was aborted")); + Assert.Throws(transaction.Commit); + + var requests = Fixture.SpannerMock.Requests.OfType().ToList(); + Assert.That(requests.Count, Is.EqualTo(1)); + } + + [Test] + public void TestBatchDml() + { + var sql1 = "update all_types set col_float8=1 where col_bigint=1"; + var sql2 = "update all_types set col_float8=2 where col_bigint=2"; + Fixture.SpannerMock.AddOrUpdateStatementResult(sql1, StatementResult.CreateUpdateCount(2)); + Fixture.SpannerMock.AddOrUpdateStatementResult(sql2, StatementResult.CreateUpdateCount(3)); + using var connection = new SpannerConnection(); + connection.ConnectionString = ConnectionString; + connection.Open(); + + using var command1 = connection.CreateCommand(); + command1.CommandText = sql1; + using var command2 = connection.CreateCommand(); + command2.CommandText = sql2; + var affected = connection.ExecuteBatchDml([command1, command2]); + Assert.That(affected, Is.EqualTo(new long[] { 2, 3 })); + } + + [Test] + public async Task TestBasicLifecycle() + { + await using var conn = new SpannerConnection(); + conn.ConnectionString = ConnectionString; + + var eventConnecting = false; + var eventOpen = false; + var eventClosed = false; + + conn.StateChange += (_, e) => + { + if (e is { OriginalState: ConnectionState.Closed, CurrentState: ConnectionState.Connecting }) + eventConnecting = true; + + if (e is { OriginalState: ConnectionState.Connecting, CurrentState: ConnectionState.Open }) + eventOpen = true; + + if (e is { OriginalState: ConnectionState.Open, CurrentState: ConnectionState.Closed }) + eventClosed = true; + }; + + Assert.That(conn.State, Is.EqualTo(ConnectionState.Closed)); + Assert.That(eventConnecting, Is.False); + Assert.That(eventOpen, Is.False); + + await conn.OpenAsync(); + + Assert.That(conn.State, Is.EqualTo(ConnectionState.Open)); + Assert.That(eventConnecting, Is.True); + Assert.That(eventOpen, Is.True); + + await using (var cmd = new SpannerCommand("SELECT 1", conn)) + await using (var reader = await cmd.ExecuteReaderAsync()) + { + await reader.ReadAsync(); + Assert.That(conn.State, Is.EqualTo(ConnectionState.Open)); + } + + Assert.That(conn.State, Is.EqualTo(ConnectionState.Open)); + + await conn.CloseAsync(); + + Assert.That(conn.State, Is.EqualTo(ConnectionState.Closed)); + Assert.That(eventClosed, Is.True); + } + + [Test] + [Ignore("SpannerLib should support a connect_timeout property to make this test quicker")] + public async Task TestInvalidHost() + { + await using var conn = new SpannerConnection(); + conn.ConnectionString = $"{Fixture.Host}_invalid:{Fixture.Port}/projects/p1/instances/i1/databases/d1;UsePlainText=true"; + var exception = Assert.Throws(() => conn.Open()); + Assert.That(exception.Code, Is.EqualTo(Code.DeadlineExceeded)); + } + + [Test] + public async Task TestInvalidDatabase() + { + // Close all current pools to ensure that we get a fresh pool. + SpannerPool.CloseSpannerLib(); + // TODO: Make this a public property in the mock server. + const string detectDialectQuery = + "select option_value from information_schema.database_options where option_name='database_dialect'"; + Fixture.SpannerMock.AddOrUpdateStatementResult(detectDialectQuery, StatementResult.CreateException(new RpcException(new Status(StatusCode.NotFound, "Database not found")))); + await using var conn = new SpannerConnection(); + conn.ConnectionString = ConnectionString; + var exception = Assert.Throws(() => conn.Open()); + Assert.That(exception.Code, Is.EqualTo(Code.NotFound)); + } + + [Test] + public void TestConnectWithConnectionStringBuilder() + { + var builder = new SpannerConnectionStringBuilder + { + DataSource = "projects/my-project/instances/my-instance/databases/my-database", + Host = Fixture.Host, + Port = (uint) Fixture.Port, + UsePlainText = true + }; + using var connection = new SpannerConnection(builder); + Assert.That(connection.ConnectionString, Is.EqualTo(builder.ConnectionString)); + } + + [Test] + public void RequiredConnectionStringProperties() + { + using var connection = new SpannerConnection(); + Assert.Throws(() => connection.ConnectionString = "Host=localhost;Port=80"); + } + + [Test] + public void FailedConnectThenSucceed() + { + // Close all current pools to ensure that we get a fresh pool. + SpannerPool.CloseSpannerLib(); + // TODO: Make this a public property in the mock server. + const string detectDialectQuery = + "select option_value from information_schema.database_options where option_name='database_dialect'"; + Fixture.SpannerMock.AddOrUpdateStatementResult(detectDialectQuery, StatementResult.CreateException(new RpcException(new Status(StatusCode.NotFound, "Database not found")))); + using var conn = new SpannerConnection(); + conn.ConnectionString = ConnectionString; + var exception = Assert.Throws(() => conn.Open()); + Assert.That(exception.Code, Is.EqualTo(Code.NotFound)); + Assert.That(conn.State, Is.EqualTo(ConnectionState.Closed)); + + // Remove the error and retry. + Fixture.SpannerMock.AddOrUpdateStatementResult(detectDialectQuery, StatementResult.CreateResultSet(new List> + { + Tuple.Create(TypeCode.String, "option_value") + }, new List + { + new object[] { "GOOGLE_STANDARD_SQL" } + })); + conn.Open(); + Assert.That(conn.State, Is.EqualTo(ConnectionState.Open)); + } + + [Test] + [Ignore("Needs connect_timeout property")] + public void OpenTimeout() + { + // Close all current pools to ensure that we get a fresh pool. + SpannerPool.CloseSpannerLib(); + Fixture.SpannerMock.AddOrUpdateExecutionTime(nameof(Fixture.SpannerMock.CreateSession), ExecutionTime.FromMillis(20, 0)); + var builder = new SpannerConnectionStringBuilder + { + Host = Fixture.Host, + Port = (uint) Fixture.Port, + UsePlainText = true, + DataSource = "projects/project1/instances/instance1/databases/database1", + ConnectionTimeout = 1, + }; + using var connection = new SpannerConnection(); + connection.ConnectionString = builder.ConnectionString; + var exception = Assert.Throws(() => connection.Open()); + Assert.That(exception.ErrorCode, Is.EqualTo((int) Code.DeadlineExceeded)); + } + + [Test] + [Ignore("OpenAsync must be implemented")] + public async Task OpenCancel() + { + // Close all current pools to ensure that we get a fresh pool. + SpannerPool.CloseSpannerLib(); + Fixture.SpannerMock.AddOrUpdateExecutionTime(nameof(Fixture.SpannerMock.CreateSession), ExecutionTime.FromMillis(20, 0)); + var builder = new SpannerConnectionStringBuilder + { + Host = Fixture.Host, + Port = (uint) Fixture.Port, + UsePlainText = true, + DataSource = "projects/project1/instances/instance1/databases/database1", + }; + await using var connection = new SpannerConnection(); + connection.ConnectionString = builder.ConnectionString; + var tokenSource = new CancellationTokenSource(5); + // TODO: Implement actual async opening of connections + Assert.ThrowsAsync(async () => await connection.OpenAsync(tokenSource.Token)); + Assert.That(connection.State, Is.EqualTo(ConnectionState.Closed)); + } + + [Test] + public void DataSourceProperty() + { + using var conn = new SpannerConnection(); + Assert.That(conn.DataSource, Is.EqualTo(string.Empty)); + + var builder = new SpannerConnectionStringBuilder(ConnectionString); + + conn.ConnectionString = builder.ConnectionString; + Assert.That(conn.DataSource, Is.EqualTo("projects/p1/instances/i1/databases/d1")); + } + + [Test] + public void SettingConnectionStringWhileOpenThrows() + { + using var conn = new SpannerConnection(); + conn.ConnectionString = ConnectionString; + conn.Open(); + Assert.That(() => conn.ConnectionString = "", Throws.Exception.TypeOf()); + } + + [Test] + public void EmptyConstructor() + { + var conn = new SpannerConnection(); + Assert.That(conn.ConnectionTimeout, Is.EqualTo(15)); + Assert.That(conn.ConnectionString, Is.SameAs(string.Empty)); + Assert.That(() => conn.Open(), Throws.Exception.TypeOf()); + } + + [Test] + public void ConstructorWithNullConnectionString() + { + var conn = new SpannerConnection((string?) null); + Assert.That(conn.ConnectionString, Is.SameAs(string.Empty)); + Assert.That(() => conn.Open(), Throws.Exception.TypeOf()); + } + + [Test] + public void ConstructorWithEmptyConnectionString() + { + var conn = new SpannerConnection(""); + Assert.That(conn.ConnectionString, Is.SameAs(string.Empty)); + Assert.That(() => conn.Open(), Throws.Exception.TypeOf()); + } + + [Test] + public void SetConnectionStringToNull() + { + var conn = new SpannerConnection(ConnectionString); + conn.ConnectionString = null; + Assert.That(conn.ConnectionString, Is.SameAs(string.Empty)); + Assert.That(() => conn.Open(), Throws.Exception.TypeOf()); + } + + [Test] + public void SetConnectionStringToEmpty() + { + var conn = new SpannerConnection(ConnectionString); + conn.ConnectionString = ""; + Assert.That(conn.ConnectionString, Is.SameAs(string.Empty)); + Assert.That(() => conn.Open(), Throws.Exception.TypeOf()); + } + + [Test] + public async Task ChangeDatabase() + { + await using var conn = await OpenConnectionAsync(); + Assert.That(conn.Database, Is.EqualTo("projects/p1/instances/i1/databases/d1")); + conn.ChangeDatabase("template1"); + Assert.That(conn.Database, Is.EqualTo("projects/p1/instances/i1/databases/template1")); + } + + [Test] + public async Task ChangeDatabaseDoesNotAffectOtherConnections() + { + await using var conn1 = new SpannerConnection(ConnectionString); + await using var conn2 = new SpannerConnection(ConnectionString); + conn1.Open(); + conn1.ChangeDatabase("template1"); + Assert.That(conn1.Database, Is.EqualTo("projects/p1/instances/i1/databases/template1")); + + // Connection 2's database should not changed + conn2.Open(); + Assert.That(conn2.Database, Is.EqualTo("projects/p1/instances/i1/databases/d1")); + } + + [Test] + public void ChangeDatabaseOnClosedConnectionWorks() + { + using var conn = new SpannerConnection(ConnectionString); + Assert.That(conn.Database, Is.EqualTo("projects/p1/instances/i1/databases/d1")); + conn.ChangeDatabase("template1"); + Assert.That(conn.Database, Is.EqualTo("projects/p1/instances/i1/databases/template1")); + } + + [Test] + [Ignore("Must add search_path connection property in shared library first")] + public async Task SearchPath() + { + // TODO: Add search_path connection variable in shared library + await using var dataSource = CreateDataSource(csb => csb.SearchPath = "foo"); + await using var conn = await dataSource.OpenConnectionAsync() as SpannerConnection; + Assert.That(await conn!.ExecuteScalarAsync("SHOW VARIABLE search_path"), Contains.Substring("foo")); + } + + [Test] + public async Task SetOptions() + { + await using var dataSource = CreateDataSource(csb => csb.Options = "isolation_level=serializable;read_lock_mode=pessimistic"); + await using var conn = await dataSource.OpenConnectionAsync() as SpannerConnection; + + Assert.That(await conn!.ExecuteScalarAsync("SHOW VARIABLE isolation_level"), Is.EqualTo("Serializable")); + Assert.That(await conn!.ExecuteScalarAsync("SHOW VARIABLE read_lock_mode"), Is.EqualTo("PESSIMISTIC")); + } + + [Test] + public async Task ConnectorNotInitializedException() + { + var command = new SpannerCommand(); + command.CommandText = "SELECT 1"; + + for (var i = 0; i < 2; i++) + { + await using var connection = await OpenConnectionAsync(); + command.Connection = connection; + await using var tx = await connection.BeginTransactionAsync(); + await command.ExecuteScalarAsync(); + await tx.CommitAsync(); + } + } + + [Test] + public void ConnectionStateIsClosedWhenDisposed() + { + var c = new SpannerConnection(); + c.Dispose(); + Assert.That(c.State, Is.EqualTo(ConnectionState.Closed)); + } + + [Test] + public async Task ConcurrentReadersAllowed() + { + await using var conn = await OpenConnectionAsync(); + await using var cmd = new SpannerCommand("SELECT 1", conn); + await using (await cmd.ExecuteReaderAsync()) + { + Assert.That(await conn.ExecuteScalarAsync("SELECT 1"), Is.EqualTo(1)); + } + } + + [Test] + public async Task ManyOpenClose() + { + await using var dataSource = CreateDataSource(); + for (var i = 0; i < 256; i++) + { + await using var conn = await dataSource.OpenConnectionAsync(); + } + await using (var conn = dataSource.CreateConnection()) + { + await conn.OpenAsync(); + } + await using (var conn = dataSource.CreateConnection() as SpannerConnection) + { + await conn!.OpenAsync(); + Assert.That(await conn.ExecuteScalarAsync("SELECT 1"), Is.EqualTo(1)); + } + } + + [Test] + public async Task ManyOpenCloseWithTransaction() + { + await using var dataSource = CreateDataSource(); + for (var i = 0; i < 256; i++) + { + await using var conn = await dataSource.OpenConnectionAsync(); + await conn.BeginTransactionAsync(); + } + + await using (var conn = await dataSource.OpenConnectionAsync() as SpannerConnection) + { + Assert.That(await conn!.ExecuteScalarAsync("SELECT 1"), Is.EqualTo(1)); + } + } + + [Test] + public async Task RollbackOnClose() + { + await using var dataSource = CreateDataSource(); + await using (var conn = await dataSource.OpenConnectionAsync() as SpannerConnection) + { + await conn!.BeginTransactionAsync(); + await conn.ExecuteNonQueryAsync("SELECT 1"); + Assert.That(conn.HasTransaction); + } + await using (var conn = await dataSource.OpenConnectionAsync() as SpannerConnection) + { + Assert.False(conn!.HasTransaction); + } + } + + [Test] + public async Task ReadLargeString() + { + const string sql = "select large_value from my_table"; + var value = TestUtils.GenerateRandomString(10_000_000); + Fixture.SpannerMock.AddOrUpdateStatementResult(sql, StatementResult.CreateSingleColumnResultSet( + new V1.Type{Code = TypeCode.String}, "large_value", value)); + + await using var dataSource = CreateDataSource(); + await using var conn = await dataSource.OpenConnectionAsync() as SpannerConnection; + var got = await conn!.ExecuteScalarAsync(sql); + Assert.That(got, Is.EqualTo(value)); + } + +} \ No newline at end of file diff --git a/drivers/spanner-ado-net/spanner-ado-net-tests/DataSourceTests.cs b/drivers/spanner-ado-net/spanner-ado-net-tests/DataSourceTests.cs new file mode 100644 index 00000000..02fc56f5 --- /dev/null +++ b/drivers/spanner-ado-net/spanner-ado-net-tests/DataSourceTests.cs @@ -0,0 +1,296 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +using System.Data; +using System.Data.Common; +using System.Diagnostics.CodeAnalysis; +using Google.Cloud.SpannerLib.MockServer; + +namespace Google.Cloud.Spanner.DataProvider.Tests; + +public class DataSourceTests : AbstractMockServerTests +{ + [Test] + public async Task CreateConnection() + { + await using var dataSource = SpannerDataSource.Create(ConnectionString); + await using var connection = dataSource.CreateConnection(); + Assert.That(connection.State, Is.EqualTo(ConnectionState.Closed)); + + await connection.OpenAsync(); + Assert.That(await connection.ExecuteScalarAsync("SELECT 1"), Is.EqualTo(1)); + } + + [Test] + [SuppressMessage("ReSharper", "MethodHasAsyncOverload")] + public async Task OpenConnection([Values] bool async) + { + await using var dataSource = SpannerDataSource.Create(ConnectionString); + await using var connection = async + ? await dataSource.OpenConnectionAsync() + : dataSource.OpenConnection(); + + Assert.That(connection.State, Is.EqualTo(ConnectionState.Open)); + Assert.That(await connection.ExecuteScalarAsync("SELECT 1"), Is.EqualTo(1)); + } + + [Test] + [SuppressMessage("ReSharper", "MethodHasAsyncOverload")] + public async Task ExecuteScalarOnConnectionlessCommand([Values] bool async) + { + await using var dataSource = SpannerDataSource.Create(ConnectionString); + await using var command = dataSource.CreateCommand(); + command.CommandText = "SELECT 1"; + + if (async) + { + Assert.That(await command.ExecuteScalarAsync(), Is.EqualTo(1)); + } + else + { + Assert.That(command.ExecuteScalar(), Is.EqualTo(1)); + } + } + + [Test] + [SuppressMessage("ReSharper", "MethodHasAsyncOverload")] + public async Task ExecuteNonQueryOnConnectionlessCommand([Values] bool async) + { + await using var dataSource = SpannerDataSource.Create(ConnectionString); + await using var command = dataSource.CreateCommand(); + command.CommandText = "SELECT 1"; + + if (async) + { + Assert.That(await command.ExecuteNonQueryAsync(), Is.EqualTo(-1)); + } + else + { + Assert.That(command.ExecuteNonQuery(), Is.EqualTo(-1)); + } + } + + [Test] + [SuppressMessage("ReSharper", "MethodHasAsyncOverload")] + public async Task ExecuteReaderOnConnectionlessCommand([Values] bool async) + { + await using var dataSource = SpannerDataSource.Create(ConnectionString); + await using var command = dataSource.CreateCommand(); + command.CommandText = "SELECT 1"; + + await using var reader = async ? await command.ExecuteReaderAsync() : command.ExecuteReader(); + Assert.That(reader.Read()); + Assert.That(reader.GetInt32(0), Is.EqualTo(1)); + } + + [Ignore("Requires support for batching queries")] + [Test] + [SuppressMessage("ReSharper", "MethodHasAsyncOverload")] + public async Task ExecuteScalarOnConnectionlessBatch([Values] bool async) + { + await using var dataSource = SpannerDataSource.Create(ConnectionString); + await using var batch = dataSource.CreateBatch(); + batch.AddSpannerBatchCommand("SELECT 1"); + batch.AddSpannerBatchCommand("SELECT 2"); + + if (async) + { + Assert.That(await batch.ExecuteScalarAsync(), Is.EqualTo(1)); + } + else + { + Assert.That(batch.ExecuteScalar(), Is.EqualTo(1)); + } + } + + [Ignore("Requires support for batching queries")] + [Test] + [SuppressMessage("ReSharper", "MethodHasAsyncOverload")] + public async Task ExecuteNonQueryOnConnectionlessBatch([Values] bool async) + { + await using var dataSource = SpannerDataSource.Create(ConnectionString); + await using var batch = dataSource.CreateBatch(); + batch.AddSpannerBatchCommand("SELECT 1"); + batch.AddSpannerBatchCommand("SELECT 2"); + + if (async) + { + Assert.That(await batch.ExecuteNonQueryAsync(), Is.EqualTo(-1)); + } + else + { + Assert.That(batch.ExecuteNonQuery(), Is.EqualTo(-1)); + } + } + + [Ignore("Requires support for batching queries")] + [Test] + [SuppressMessage("ReSharper", "MethodHasAsyncOverload")] + public async Task ExecuteReaderOnConnectionlessBatch([Values] bool async) + { + await using var dataSource = SpannerDataSource.Create(ConnectionString); + await using var batch = dataSource.CreateBatch(); + batch.AddSpannerBatchCommand("SELECT 1"); + batch.AddSpannerBatchCommand("SELECT 2"); + + await using var reader = async ? await batch.ExecuteReaderAsync() : batch.ExecuteReader(); + Assert.That(reader.Read()); + Assert.That(reader.GetInt32(0), Is.EqualTo(1)); + Assert.That(reader.NextResult()); + Assert.That(reader.Read()); + Assert.That(reader.GetInt32(0), Is.EqualTo(2)); + } + + [Test] + [SuppressMessage("ReSharper", "MethodHasAsyncOverload")] + public async Task ExecuteScalarOnConnectionlessDmlBatch([Values] bool async) + { + const string sql = "insert into my_table (id) values (default)"; + Fixture.SpannerMock.AddOrUpdateStatementResult(sql, StatementResult.CreateUpdateCount(1)); + + await using var dataSource = SpannerDataSource.Create(ConnectionString); + await using var batch = dataSource.CreateBatch(); + batch.AddSpannerBatchCommand(sql); + batch.AddSpannerBatchCommand(sql); + + if (async) + { + Assert.ThrowsAsync(async () => await batch.ExecuteScalarAsync()); + } + else + { + Assert.Throws(() => batch.ExecuteScalar()); + } + } + + [Test] + [SuppressMessage("ReSharper", "MethodHasAsyncOverload")] + public async Task ExecuteNonQueryOnConnectionlessDmlBatch([Values] bool async) + { + const string sql = "insert into my_table (id) values (default)"; + Fixture.SpannerMock.AddOrUpdateStatementResult(sql, StatementResult.CreateUpdateCount(1)); + + await using var dataSource = SpannerDataSource.Create(ConnectionString); + await using var batch = dataSource.CreateBatch(); + batch.AddSpannerBatchCommand(sql); + batch.AddSpannerBatchCommand(sql); + + if (async) + { + Assert.That(await batch.ExecuteNonQueryAsync(), Is.EqualTo(2)); + } + else + { + Assert.That(batch.ExecuteNonQuery(), Is.EqualTo(2)); + } + } + + [Test] + [SuppressMessage("ReSharper", "MethodHasAsyncOverload")] + public async Task ExecuteReaderOnConnectionlessDmlBatch([Values] bool async) + { + const string sql = "insert into my_table (id) values (default)"; + Fixture.SpannerMock.AddOrUpdateStatementResult(sql, StatementResult.CreateUpdateCount(1)); + + await using var dataSource = SpannerDataSource.Create(ConnectionString); + await using var batch = dataSource.CreateBatch(); + batch.AddSpannerBatchCommand(sql); + batch.AddSpannerBatchCommand(sql); + + if (async) + { + Assert.ThrowsAsync(async () => await batch.ExecuteReaderAsync()); + } + else + { + Assert.Throws(() => batch.ExecuteReader()); + } + } + + [Test] + public void Dispose() + { + var dataSource = SpannerDataSource.Create(ConnectionString); + var connection1 = dataSource.OpenConnection(); + var connection2 = dataSource.OpenConnection(); + connection1.Close(); + + // SpannerDataSource does not contain any state, so disposing it is a no-op. + dataSource.Dispose(); + using var connection3 = dataSource.OpenConnection(); + + connection2.Close(); + } + + [Test] + public async Task DisposeAsync() + { + var dataSource = SpannerDataSource.Create(ConnectionString); + var connection1 = await dataSource.OpenConnectionAsync(); + var connection2 = await dataSource.OpenConnectionAsync(); + await connection1.CloseAsync(); + + // SpannerDataSource does not contain any state, so disposing it is a no-op. + await dataSource.DisposeAsync(); + await using var connection3 = await dataSource.OpenConnectionAsync(); + + await connection2.CloseAsync(); + } + + [Test] + public async Task CannotAccessConnectionTransactionOnDataSourceCommand() + { + await using var command = DataSource.CreateCommand(); + + Assert.That(() => command.Connection, Throws.Exception.TypeOf()); + Assert.That(() => command.Connection = null, Throws.Exception.TypeOf()); + Assert.That(() => command.Transaction, Throws.Exception.TypeOf()); + Assert.That(() => command.Transaction = null, Throws.Exception.TypeOf()); + + Assert.That(() => command.Prepare(), Throws.Exception.TypeOf()); + Assert.That(() => command.PrepareAsync(), Throws.Exception.TypeOf()); + } + + [Test] + public async Task CannotAccessConnectionTransactionOnDataSourceBatch() + { + await using var batch = DataSource.CreateBatch(); + + Assert.That(() => batch.Connection, Throws.Exception.TypeOf()); + Assert.That(() => batch.Connection = null, Throws.Exception.TypeOf()); + Assert.That(() => batch.Transaction, Throws.Exception.TypeOf()); + Assert.That(() => batch.Transaction = null, Throws.Exception.TypeOf()); + + Assert.That(() => batch.Prepare(), Throws.Exception.TypeOf()); + Assert.That(() => batch.PrepareAsync(), Throws.Exception.TypeOf()); + } + + [Test] + [SuppressMessage("ReSharper", "MethodHasAsyncOverload")] + public async Task AsDbDataSource([Values] bool async) + { + await using DbDataSource dataSource = SpannerDataSource.Create(ConnectionString); + await using var connection = async + ? await dataSource.OpenConnectionAsync() + : dataSource.OpenConnection(); + Assert.That(connection.State, Is.EqualTo(ConnectionState.Open)); + + await using var command = dataSource.CreateCommand("SELECT 1"); + + Assert.That(async + ? await command.ExecuteScalarAsync() + : command.ExecuteScalar(), Is.EqualTo(1)); + } + +} \ No newline at end of file diff --git a/drivers/spanner-ado-net/spanner-ado-net-tests/README.md b/drivers/spanner-ado-net/spanner-ado-net-tests/README.md new file mode 100644 index 00000000..8e93c1e4 --- /dev/null +++ b/drivers/spanner-ado-net/spanner-ado-net-tests/README.md @@ -0,0 +1,5 @@ +# Spanner ADO.NET Data Provider Tests + +Tests for ADO.NET Data Provider for Spanner. + +__ALPHA: Not for production use__ diff --git a/drivers/spanner-ado-net/spanner-ado-net-tests/ReaderTests.cs b/drivers/spanner-ado-net/spanner-ado-net-tests/ReaderTests.cs new file mode 100644 index 00000000..b3de813f --- /dev/null +++ b/drivers/spanner-ado-net/spanner-ado-net-tests/ReaderTests.cs @@ -0,0 +1,1469 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +using System.Buffers.Binary; +using System.Collections; +using System.Data; +using System.Diagnostics.CodeAnalysis; +using System.Runtime.InteropServices; +using System.Text; +using Google.Cloud.Spanner.V1; +using Google.Cloud.SpannerLib.MockServer; +using Grpc.Core; +using TypeCode = Google.Cloud.Spanner.V1.TypeCode; + +namespace Google.Cloud.Spanner.DataProvider.Tests; + +public class ReaderTests : AbstractMockServerTests +{ + [Test] + public async Task ResumableNonConsumedToNonResumable() + { + var base64Value = Convert.ToBase64String([1, 2, 3, 4, 5, 6, 7, 8, 9, 10]); + var sql = $"SELECT from_base64('{base64Value}'), 1"; + Fixture.SpannerMock.AddOrUpdateStatementResult(sql, StatementResult.CreateResultSet( + [Tuple.Create(TypeCode.Bytes, "c"), Tuple.Create(TypeCode.Int64, "c")], + [[base64Value, 1L]])); + + await using var conn = await OpenConnectionAsync(); + await using var cmd = new SpannerCommand(sql, conn); + await using var reader = await cmd.ExecuteReaderAsync(); + await reader.ReadAsync(); + + await reader.IsDBNullAsync(0); + _ = reader.IsDBNull(0); + await using var stream = reader.GetStream(0); + Assert.That(reader.GetString(0), Is.EqualTo(base64Value)); + } + + [Test] + [SuppressMessage("ReSharper", "MethodHasAsyncOverload")] + public async Task SeekColumns() + { + const string sql = "SELECT 1,2,3"; + Fixture.SpannerMock.AddOrUpdateStatementResult(sql, StatementResult.CreateResultSet( + [Tuple.Create(TypeCode.Int64, "c"), Tuple.Create(TypeCode.Int64, "c"), Tuple.Create(TypeCode.Int64, "c")], + [[1,2,3]])); + + await using var conn = await OpenConnectionAsync(); + await using var cmd = new SpannerCommand(sql, conn); + await using var reader = await cmd.ExecuteReaderAsync(); + Assert.That(reader.Read(), Is.True); + Assert.That(reader.GetInt32(0), Is.EqualTo(1)); + Assert.That(reader.GetInt32(0), Is.EqualTo(1)); + Assert.That(reader.GetInt32(1), Is.EqualTo(2)); + Assert.That(reader.GetInt32(0), Is.EqualTo(1)); + } + + [Ignore("Requires multi-statement support")] + [Test] + [SuppressMessage("ReSharper", "MethodHasAsyncOverload")] + public async Task NoResultSet() + { + const string insert = "INSERT INTO my_table VALUES (8)"; + Fixture.SpannerMock.AddOrUpdateStatementResult(insert, StatementResult.CreateUpdateCount(1L)); + + await using var conn = await OpenConnectionAsync(); + + await using (var cmd = new SpannerCommand(insert, conn)) + await using (var reader = await cmd.ExecuteReaderAsync()) + { + Assert.That(() => reader.GetOrdinal("foo"), Throws.Exception.TypeOf()); + Assert.That(reader.Read(), Is.False); + Assert.That(() => reader.GetOrdinal("foo"), Throws.Exception.TypeOf()); + Assert.That(reader.FieldCount, Is.EqualTo(0)); + Assert.That(reader.NextResult(), Is.False); + Assert.That(() => reader.GetOrdinal("foo"), Throws.Exception.TypeOf()); + } + + await using (var cmd = new SpannerCommand($"SELECT 1; {insert}", conn)) + await using (var reader = await cmd.ExecuteReaderAsync()) + { + await reader.NextResultAsync(); + Assert.That(() => reader.GetOrdinal("foo"), Throws.Exception.TypeOf()); + Assert.That(reader.Read(), Is.False); + Assert.That(() => reader.GetOrdinal("foo"), Throws.Exception.TypeOf()); + Assert.That(reader.FieldCount, Is.EqualTo(0)); + } + } + + [Test] + [SuppressMessage("ReSharper", "MethodHasAsyncOverload")] + public async Task EmptyResultSet() + { + const string sql = "SELECT 1 AS foo WHERE FALSE"; + Fixture.SpannerMock.AddOrUpdateStatementResult(sql, StatementResult.CreateResultSet( + [Tuple.Create(TypeCode.Int64, "foo")], [])); + + await using var conn = await OpenConnectionAsync(); + await using var cmd = new SpannerCommand(sql, conn); + await using var reader = await cmd.ExecuteReaderAsync(); + Assert.That(reader.Read(), Is.False); + Assert.That(reader.FieldCount, Is.EqualTo(1)); + Assert.That(reader.GetOrdinal("foo"), Is.EqualTo(0)); + Assert.That(() => reader[0], Throws.Exception.TypeOf()); + } + + [Ignore("Requires multi-statement support")] + [Test] + [SuppressMessage("ReSharper", "MethodHasAsyncOverload")] + public async Task FieldCount() + { + await using var conn = await OpenConnectionAsync(); + + await using var cmd = new SpannerCommand("SELECT 1; SELECT 2,3", conn); + await using (var reader = await cmd.ExecuteReaderAsync()) + { + Assert.That(reader.FieldCount, Is.EqualTo(1)); + Assert.That(reader.Read(), Is.True); + Assert.That(reader.FieldCount, Is.EqualTo(1)); + Assert.That(reader.Read(), Is.False); + Assert.That(reader.FieldCount, Is.EqualTo(1)); + Assert.That(reader.NextResult(), Is.True); + Assert.That(reader.FieldCount, Is.EqualTo(2)); + Assert.That(reader.NextResult(), Is.False); + Assert.That(reader.FieldCount, Is.EqualTo(0)); + } + + cmd.CommandText = $"INSERT INTO my_table (int) VALUES (1)"; + await using (var reader = await cmd.ExecuteReaderAsync()) + { + Assert.That(() => reader.FieldCount, Is.EqualTo(0)); + } + } + + [Ignore("Requires multi-statement support")] + [Test] + [SuppressMessage("ReSharper", "MethodHasAsyncOverload")] + public async Task RecordsAffected() + { + const int insertCount = 15; + for (var i = 0; i < insertCount; i++) + { + Fixture.SpannerMock.AddOrUpdateStatementResult($"INSERT INTO my_table (int) VALUES ({i});", StatementResult.CreateUpdateCount(1)); + } + + await using var conn = await OpenConnectionAsync(); + + var sb = new StringBuilder(); + for (var i = 0; i < 10; i++) + { + sb.Append($"INSERT INTO my_table (int) VALUES ({i});"); + } + // Testing, that on close reader consumes all rows (as insert doesn't have a result set, but select does) + sb.Append("SELECT 1;"); + for (var i = 10; i < 15; i++) + { + sb.Append($"INSERT INTO my_table (int) VALUES ({i});"); + } + + var cmd = new SpannerCommand(sb.ToString(), conn); + var reader = await cmd.ExecuteReaderAsync(); + reader.Close(); + Assert.That(reader.RecordsAffected, Is.EqualTo(insertCount)); + + const string select = "SELECT * FROM my_table"; + Fixture.SpannerMock.AddOrUpdateStatementResult(select, StatementResult.CreateResultSet( + [Tuple.Create(TypeCode.Int64, "int")], + [[1], [2], [3], [4], [5], [6], [7], [8], [9], [10], [11], [12], [13], [14], [15]])); + cmd = new SpannerCommand(select, conn); + reader = await cmd.ExecuteReaderAsync(); + reader.Close(); + Assert.That(reader.RecordsAffected, Is.EqualTo(-1)); + + const string update = "UPDATE my_table SET int=int+1 WHERE int > 10"; + Fixture.SpannerMock.AddOrUpdateStatementResult(update, StatementResult.CreateUpdateCount(5)); + cmd = new SpannerCommand(update, conn); + reader = await cmd.ExecuteReaderAsync(); + reader.Close(); + Assert.That(reader.RecordsAffected, Is.EqualTo(4)); + + const string noUpdate = "UPDATE my_table SET int=8 WHERE int=666"; + Fixture.SpannerMock.AddOrUpdateStatementResult(noUpdate, StatementResult.CreateUpdateCount(0)); + cmd = new SpannerCommand(noUpdate, conn); + reader = await cmd.ExecuteReaderAsync(); + reader.Close(); + Assert.That(reader.RecordsAffected, Is.EqualTo(0)); + + const string delete = "DELETE FROM my_table WHERE int > 10"; + Fixture.SpannerMock.AddOrUpdateStatementResult(delete, StatementResult.CreateUpdateCount(4)); + cmd = new SpannerCommand(delete, conn); + reader = await cmd.ExecuteReaderAsync(); + reader.Close(); + Assert.That(reader.RecordsAffected, Is.EqualTo(4)); + } + + [Test] + [SuppressMessage("ReSharper", "MethodHasAsyncOverload")] + public async Task GetStringWithParameter() + { + await using var conn = await OpenConnectionAsync(); + const string text = "Random text"; + const string sql = "SELECT name FROM my_table WHERE name = @value;"; + Fixture.SpannerMock.AddOrUpdateStatementResult(sql, StatementResult.CreateResultSet( + [Tuple.Create(TypeCode.String, "name")], [[text]])); + + var command = new SpannerCommand(sql, conn); + var param = new SpannerParameter + { + ParameterName = "value", + DbType = DbType.String, + Size = text.Length, + Value = text + }; + command.Parameters.Add(param); + + await using var dr = await command.ExecuteReaderAsync(); + dr.Read(); + var result = dr.GetString(0); + Assert.That(result, Is.EqualTo(text)); + } + + [Test] + [SuppressMessage("ReSharper", "UseAwaitUsing")] + [SuppressMessage("ReSharper", "MethodHasAsyncOverload")] + public async Task GetStringWithQuoteWithParameter() + { + const string test = "Text with ' single quote"; + const string sql = "SELECT name FROM my_table WHERE name = @value;"; + Fixture.SpannerMock.AddOrUpdateStatementResult(sql, StatementResult.CreateResultSet( + [Tuple.Create(TypeCode.String, "name")], [[test]])); + + using var conn = await OpenConnectionAsync(); + var command = new SpannerCommand(sql, conn); + + var param = new SpannerParameter + { + ParameterName = "value", + DbType = DbType.String, + Size = test.Length, + Value = test + }; + command.Parameters.Add(param); + + using var dr = await command.ExecuteReaderAsync(); + dr.Read(); + var result = dr.GetString(0); + Assert.That(result, Is.EqualTo(test)); + } + + [Test] + [SuppressMessage("ReSharper", "UseAwaitUsing")] + [SuppressMessage("ReSharper", "MethodHasAsyncOverload")] + public async Task GetValueByName() + { + const string sql = "SELECT 'Random text' AS real_column"; + Fixture.SpannerMock.AddOrUpdateStatementResult(sql, StatementResult.CreateResultSet( + [Tuple.Create(TypeCode.String, "real_column")], [["Random text"]])); + + await using var conn = await OpenConnectionAsync(); + using var command = new SpannerCommand(sql, conn); + using var dr = await command.ExecuteReaderAsync(); + dr.Read(); + Assert.That(dr["real_column"], Is.EqualTo("Random text")); + Assert.That(() => dr["non_existing"], Throws.Exception.TypeOf()); + } + + [Test] + [SuppressMessage("ReSharper", "UseAwaitUsing")] + [SuppressMessage("ReSharper", "MethodHasAsyncOverload")] + [SuppressMessage("ReSharper", "ConvertToUsingDeclaration")] + public async Task GetFieldType() + { + const string sql = "SELECT cast(1 as int64) AS some_column"; + Fixture.SpannerMock.AddOrUpdateStatementResult(sql, StatementResult.CreateResultSet( + [Tuple.Create(TypeCode.Int64, "some_column")], [[1L]])); + + using var conn = await OpenConnectionAsync(); + using (var cmd = new SpannerCommand(sql, conn)) + using (var reader = await cmd.ExecuteReaderAsync()) + { + reader.Read(); + Assert.That(reader.GetFieldType(0), Is.SameAs(typeof(long))); + } + } + + [Test] + [SuppressMessage("ReSharper", "MethodHasAsyncOverload")] + public async Task GetFieldType_SchemaOnly() + { + const string sql = "SELECT cast(1 as int64) AS some_column"; + Fixture.SpannerMock.AddOrUpdateStatementResult(sql, StatementResult.CreateResultSet( + [Tuple.Create(TypeCode.Int64, "some_column")], [[]])); + + await using var conn = await OpenConnectionAsync(); + await using var cmd = new SpannerCommand(sql, conn); + await using var reader = await cmd.ExecuteReaderAsync(CommandBehavior.SchemaOnly); + Assert.False(reader.Read()); + Assert.That(reader.GetFieldType(0), Is.SameAs(typeof(long))); + + var request = Fixture.SpannerMock.Requests.OfType().First(); + Assert.That(request, Is.Not.Null); + Assert.That(request.QueryMode, Is.EqualTo(ExecuteSqlRequest.Types.QueryMode.Plan)); + } + + [SuppressMessage("ReSharper", "UseAwaitUsing")] + [SuppressMessage("ReSharper", "MethodHasAsyncOverload")] + [Test] + public async Task GetDataTypeName([Values] TypeCode typeCode) + { + if (typeCode == TypeCode.Array || typeCode == TypeCode.Unspecified) + { + return; + } + + var sql = $"SELECT cast(NULL as {typeCode.ToString()} AS some_column"; + Fixture.SpannerMock.AddOrUpdateStatementResult(sql, StatementResult.CreateResultSet( + [Tuple.Create(typeCode, "some_column")], [])); + + using var conn = await OpenConnectionAsync(); + using var cmd = new SpannerCommand(sql, conn); + using var reader = await cmd.ExecuteReaderAsync(); + reader.Read(); + Assert.That(reader.GetDataTypeName(0), Is.EqualTo(typeCode.ToString().ToUpper())); + } + + [Test] + [SuppressMessage("ReSharper", "UseAwaitUsing")] + public async Task GetName() + { + const string sql = "SELECT 1 AS some_column"; + Fixture.SpannerMock.AddOrUpdateStatementResult(sql, StatementResult.CreateResultSet( + [Tuple.Create(TypeCode.Int64, "some_column")], [[1L]])); + + using var conn = await OpenConnectionAsync(); + using var command = new SpannerCommand(sql, conn); + using var dr = await command.ExecuteReaderAsync(); + await dr.ReadAsync(); + Assert.That(dr.GetName(0), Is.EqualTo("some_column")); + } + + [Test] + [SuppressMessage("ReSharper", "UseAwaitUsing")] + [SuppressMessage("ReSharper", "MethodHasAsyncOverload")] + public async Task GetFieldValueAsObject() + { + const string sql = "SELECT 'foo'"; + Fixture.SpannerMock.AddOrUpdateStatementResult(sql, StatementResult.CreateResultSet( + [Tuple.Create(TypeCode.String, "c")], [["foo"]])); + + using var conn = await OpenConnectionAsync(); + using var cmd = new SpannerCommand(sql, conn); + using var reader = await cmd.ExecuteReaderAsync(); + reader.Read(); + Assert.That(reader.GetFieldValue(0), Is.EqualTo("foo")); + } + + [Test] + public async Task GetValues() + { + const string sql = "SELECT 'hello', 1, cast('2014-01-01' as DATE)"; + Fixture.SpannerMock.AddOrUpdateStatementResult(sql, StatementResult.CreateResultSet( + [ + Tuple.Create(TypeCode.String, "c1"), + Tuple.Create(TypeCode.Int64, "c2"), + Tuple.Create(TypeCode.Date, "c3"), + ], [["hello", 1L, new DateOnly(2014, 1, 1)]])); + + await using var conn = await OpenConnectionAsync(); + await using var command = new SpannerCommand(sql, conn); + await using (var dr = await command.ExecuteReaderAsync()) + { + await dr.ReadAsync(); + var values = new object[4]; + Assert.That(dr.GetValues(values), Is.EqualTo(3)); + Assert.That(values, Is.EqualTo(new object?[] { "hello", 1, new DateOnly(2014, 1, 1), null })); + } + await using (var dr = await command.ExecuteReaderAsync()) + { + await dr.ReadAsync(); + var values = new object[2]; + Assert.That(dr.GetValues(values), Is.EqualTo(2)); + Assert.That(values, Is.EqualTo(new object[] { "hello", 1 })); + } + } + + [Test] + [SuppressMessage("ReSharper", "UseAwaitUsing")] + [SuppressMessage("ReSharper", "MethodHasAsyncOverload")] + public async Task ExecuteReaderGettingEmptyResultSetWithOutputParameter() + { + const string sql = "SELECT * FROM my_table WHERE name = NULL;"; + Fixture.SpannerMock.AddOrUpdateStatementResult(sql, StatementResult.CreateResultSet( + [Tuple.Create(TypeCode.String, "c")], [])); + + using var conn = await OpenConnectionAsync(); + var command = new SpannerCommand(sql, conn); + var param = new SpannerParameter("some_param", DbType.String) + { + Direction = ParameterDirection.Output + }; + command.Parameters.Add(param); + using var dr = await command.ExecuteReaderAsync(); + Assert.That(dr.NextResult(), Is.False); + } + + [Test] + public async Task GetValueFromEmptyResultSet() + { + const string sql = "SELECT * FROM my_table WHERE name = :value;"; + Fixture.SpannerMock.AddOrUpdateStatementResult(sql, StatementResult.CreateResultSet( + [Tuple.Create(TypeCode.String, "name")], [])); + + await using var conn = await OpenConnectionAsync(); + await using var command = new SpannerCommand(sql, conn); + const string test = "Text single quote"; + var param = new SpannerParameter + { + ParameterName = "value", + DbType = DbType.String, + Size = test.Length, + Value = test + }; + command.Parameters.Add(param); + + await using var dr = await command.ExecuteReaderAsync(); + Assert.False(await dr.ReadAsync()); + // This line should throw the invalid operation exception as the data reader will + // have an empty resultset. + Assert.That(() => Console.WriteLine(dr.IsDBNull(0)), + Throws.Exception.TypeOf()); + } + + [Test] + [SuppressMessage("ReSharper", "UseAwaitUsing")] + [SuppressMessage("ReSharper", "MethodHasAsyncOverload")] + public async Task ReadPastReaderEnd() + { + using var conn = await OpenConnectionAsync(); + var command = new SpannerCommand("SELECT 1", conn); + using var dr = await command.ExecuteReaderAsync(); + while (dr.Read()) {} + Assert.That(() => dr[0], Throws.Exception.TypeOf()); + } + + [Ignore("Require multi-statement support")] + [Test] + [SuppressMessage("ReSharper", "UseAwaitUsing")] + [SuppressMessage("ReSharper", "MethodHasAsyncOverload")] + public async Task SingleResult() + { + await using var conn = await OpenConnectionAsync(); + await using var command = new SpannerCommand("SELECT 1; SELECT 2", conn); + var reader = await command.ExecuteReaderAsync(CommandBehavior.SingleResult); + Assert.That(reader.Read(), Is.True); + Assert.That(reader.GetInt32(0), Is.EqualTo(1)); + Assert.That(reader.NextResult(), Is.False); + } + + [Test] + [SuppressMessage("ReSharper", "UseAwaitUsing")] + [SuppressMessage("ReSharper", "MethodHasAsyncOverload")] + public async Task Exception_thrown_from_ExecuteReaderAsync([Values(PrepareOrNot.Prepared, PrepareOrNot.NotPrepared)] PrepareOrNot prepare) + { + const string sql = "SELECT error('test')"; + + using var conn = await OpenConnectionAsync(); + + using var cmd = new SpannerCommand(sql, conn); + if (prepare == PrepareOrNot.Prepared) + { + Fixture.SpannerMock.AddOrUpdateStatementResult(sql, StatementResult.CreateResultSet( + [Tuple.Create(TypeCode.String, "error")], [])); + cmd.Prepare(); + } + Fixture.SpannerMock.AddOrUpdateStatementResult(sql, StatementResult.CreateException(new RpcException(new Status(StatusCode.Internal, "test")))); + Assert.That(async () => await cmd.ExecuteReaderAsync(), Throws.Exception.TypeOf()); + } + + [Ignore("Require multi-statement support")] + [Test] + public async Task Exception_thrown_from_NextResult([Values(PrepareOrNot.Prepared, PrepareOrNot.NotPrepared)] PrepareOrNot prepare) + { + const string select1 = "SELECT 1"; + const string selectError = "SELECT error('test')"; + + await using var conn = await OpenConnectionAsync(); + + await using var cmd = new SpannerCommand($"{select1}; {selectError}", conn); + if (prepare == PrepareOrNot.Prepared) + { + Fixture.SpannerMock.AddOrUpdateStatementResult(selectError, StatementResult.CreateResultSet( + [Tuple.Create(TypeCode.String, "error")], [])); + await cmd.PrepareAsync(); + } + + Fixture.SpannerMock.AddOrUpdateStatementResult(selectError, StatementResult.CreateException(new RpcException(new Status(StatusCode.Internal, "test")))); + await using var reader = await cmd.ExecuteReaderAsync(); + Assert.That(() => reader.NextResult(), Throws.Exception.TypeOf()); + } + + [Test] + public async Task SchemaOnlyReturnsNoData() + { + const string sql = "SELECT * FROM my_table"; + Fixture.SpannerMock.AddOrUpdateStatementResult(sql, StatementResult.CreateResultSet( + [Tuple.Create(TypeCode.String, "c")], [])); + + await using var conn = await OpenConnectionAsync(); + await using var cmd = new SpannerCommand(sql, conn); + await using var reader = await cmd.ExecuteReaderAsync(CommandBehavior.SchemaOnly); + Assert.That(await reader.ReadAsync(), Is.False); + + var request = Fixture.SpannerMock.Requests.OfType().First(); + Assert.That(request, Is.Not.Null); + Assert.That(request.QueryMode, Is.EqualTo(ExecuteSqlRequest.Types.QueryMode.Plan)); + } + + [Test] + public async Task SchemaOnlyNextResultBeyondEnd() + { + const string sql = "SELECT * FROM my_table"; + Fixture.SpannerMock.AddOrUpdateStatementResult(sql, StatementResult.CreateResultSet( + [Tuple.Create(TypeCode.Int64, "id")], [])); + + await using var conn = await OpenConnectionAsync(); + + await using var cmd = new SpannerCommand(sql, conn); + await using var reader = await cmd.ExecuteReaderAsync(CommandBehavior.SchemaOnly); + Assert.That(await reader.NextResultAsync(), Is.False); + Assert.That(await reader.NextResultAsync(), Is.False); + } + + [Test] + [SuppressMessage("ReSharper", "UseAwaitUsing")] + public async Task GetOrdinal() + { + const string sql = "SELECT 0, 1 AS some_column WHERE 1=0"; + Fixture.SpannerMock.AddOrUpdateStatementResult(sql, StatementResult.CreateResultSet( + [Tuple.Create(TypeCode.Int64, "c"), Tuple.Create(TypeCode.Int64, "some_column")], [])); + + using var conn = await OpenConnectionAsync(); + using var command = new SpannerCommand(sql, conn); + using var reader = await command.ExecuteReaderAsync(); + Assert.That(reader.GetOrdinal("some_column"), Is.EqualTo(1)); + Assert.That(() => reader.GetOrdinal("doesn't_exist"), Throws.Exception.TypeOf()); + } + + [Test] + [SuppressMessage("ReSharper", "UseAwaitUsing")] + public async Task GetOrdinalCaseInsensitive() + { + const string sql = "select 123 as FIELD1"; + Fixture.SpannerMock.AddOrUpdateStatementResult(sql, StatementResult.CreateResultSet( + [Tuple.Create(TypeCode.Int64, "FIELD1")], [[123L]])); + + using var conn = await OpenConnectionAsync(); + using var command = new SpannerCommand(sql, conn); + using var reader = await command.ExecuteReaderAsync(); + await reader.ReadAsync(); + Assert.That(reader.GetOrdinal("fieLd1"), Is.EqualTo(0)); + } + + [Test] + public async Task FieldIndexDoesNotExist() + { + await using var conn = await OpenConnectionAsync(); + await using var command = new SpannerCommand("SELECT 1", conn); + await using var dr = await command.ExecuteReaderAsync(); + await dr.ReadAsync(); + Assert.That(() => dr[5], Throws.Exception.TypeOf()); + } + + [Test] + [SuppressMessage("ReSharper", "MethodHasAsyncOverload")] + public async Task ReaderStillOpen_CanExecuteMoreCommands() + { + await using var conn = await OpenConnectionAsync(); + await using var cmd1 = new SpannerCommand("SELECT 1", conn); + await using var reader1 = await cmd1.ExecuteReaderAsync(); + Assert.That(conn.ExecuteNonQuery("SELECT 1"), Is.EqualTo(-1)); + Assert.That(await conn.ExecuteNonQueryAsync("SELECT 1"), Is.EqualTo(-1)); + Assert.That(conn.ExecuteScalar("SELECT 1"), Is.EqualTo(1)); + Assert.That(await conn.ExecuteScalarAsync("SELECT 1"), Is.EqualTo(1)); + } + + [Test] + [SuppressMessage("ReSharper", "MethodHasAsyncOverload")] + public async Task CleansUpOkWithDisposeCalls([Values(PrepareOrNot.Prepared, PrepareOrNot.NotPrepared)] PrepareOrNot prepare) + { + await using var conn = await OpenConnectionAsync(); + await using var command = new SpannerCommand("SELECT 1", conn); + await using var dr = await command.ExecuteReaderAsync(); + await dr.ReadAsync(); + dr.Close(); + + await using var upd = conn.CreateCommand(); + upd.CommandText = "SELECT 1"; + if (prepare == PrepareOrNot.Prepared) + { + upd.Prepare(); + } + } + + [Test] + [SuppressMessage("ReSharper", "MethodHasAsyncOverload")] + [SuppressMessage("ReSharper", "UseAwaitUsing")] + public async Task Null() + { + const string sql = "SELECT @p1, cast(@p2 as string)"; + Fixture.SpannerMock.AddOrUpdateStatementResult(sql, StatementResult.CreateResultSet( + [Tuple.Create(TypeCode.String, "p1"), Tuple.Create(TypeCode.String, "p2")], [[DBNull.Value, DBNull.Value]])); + + using var conn = await OpenConnectionAsync(); + using var cmd = new SpannerCommand(sql, conn); + cmd.Parameters.Add(new SpannerParameter("p1", DbType.String) { Value = DBNull.Value }); + cmd.Parameters.Add(new SpannerParameter { ParameterName = "p2", Value = DBNull.Value }); + + using var reader = await cmd.ExecuteReaderAsync(); + reader.Read(); + + for (var i = 0; i < cmd.Parameters.Count; i++) + { + Assert.That(reader.IsDBNull(i), Is.True); + Assert.That(reader.IsDBNullAsync(i).Result, Is.True); + Assert.That(reader.GetValue(i), Is.EqualTo(DBNull.Value)); + Assert.That(reader.GetFieldValue(i), Is.EqualTo(DBNull.Value)); + Assert.That(reader.GetProviderSpecificValue(i), Is.EqualTo(DBNull.Value)); + Assert.That(() => reader.GetString(i), Throws.Exception.TypeOf()); + } + } + + [Ignore("Requires multi-statement support")] + [Test] + [SuppressMessage("ReSharper", "MethodHasAsyncOverload")] + [SuppressMessage("ReSharper", "UseAwaitUsing")] + public async Task HasRows([Values(PrepareOrNot.NotPrepared, PrepareOrNot.Prepared)] PrepareOrNot prepare) + { + using var conn = await OpenConnectionAsync(); + + var command = new SpannerCommand($"SELECT 1; SELECT * FROM my_table WHERE name='does_not_exist'", conn); + if (prepare == PrepareOrNot.Prepared) + { + command.Prepare(); + } + using (var reader = await command.ExecuteReaderAsync()) + { + Assert.That(reader.HasRows, Is.True); + Assert.That(reader.HasRows, Is.True); + Assert.That(reader.Read(), Is.True); + Assert.That(reader.HasRows, Is.True); + Assert.That(reader.Read(), Is.False); + Assert.That(reader.HasRows, Is.True); + await reader.NextResultAsync(); + Assert.That(reader.HasRows, Is.False); + } + + command.CommandText = "SELECT * FROM my_table"; + if (prepare == PrepareOrNot.Prepared) + { + command.Prepare(); + } + using (var reader = await command.ExecuteReaderAsync()) + { + reader.Read(); + Assert.That(reader.HasRows, Is.False); + } + + command.CommandText = "SELECT 1"; + if (prepare == PrepareOrNot.Prepared) + { + command.Prepare(); + } + using (var reader = await command.ExecuteReaderAsync()) + { + reader.Read(); + reader.Close(); + Assert.Throws(() => _ = reader.HasRows); + } + + command.CommandText = $"INSERT INTO my_table (name) VALUES ('foo'); SELECT * FROM my_table"; + if (prepare == PrepareOrNot.Prepared) + { + command.Prepare(); + } + using (var reader = await command.ExecuteReaderAsync()) + { + Assert.That(reader.HasRows, Is.True); + reader.Read(); + Assert.That(reader.GetString(0), Is.EqualTo("foo")); + } + + Assert.That(await conn.ExecuteScalarAsync("SELECT 1"), Is.EqualTo(1)); + } + + [Test] + [SuppressMessage("ReSharper", "MethodHasAsyncOverload")] + [SuppressMessage("ReSharper", "UseAwaitUsing")] + public async Task HasRowsSingleStatement([Values(PrepareOrNot.NotPrepared, PrepareOrNot.Prepared)] PrepareOrNot prepare) + { + const string selectNoRows = "SELECT * FROM my_table WHERE name='does_not_exist'"; + Fixture.SpannerMock.AddOrUpdateStatementResult(selectNoRows, StatementResult.CreateResultSet( + [Tuple.Create(TypeCode.String, "name")], [])); + + using var conn = await OpenConnectionAsync(); + + var command = new SpannerCommand("SELECT 1", conn); + if (prepare == PrepareOrNot.Prepared) + { + command.Prepare(); + } + using (var reader = await command.ExecuteReaderAsync()) + { + Assert.That(reader.HasRows, Is.True); + Assert.That(reader.HasRows, Is.True); + Assert.That(reader.Read(), Is.True); + Assert.That(reader.HasRows, Is.True); + Assert.That(reader.Read(), Is.False); + Assert.That(reader.HasRows, Is.True); + Assert.False(await reader.NextResultAsync()); + } + + command.CommandText = selectNoRows; + if (prepare == PrepareOrNot.Prepared) + { + command.Prepare(); + } + using (var reader = await command.ExecuteReaderAsync()) + { + Assert.That(reader.HasRows, Is.False); + Assert.That(reader.HasRows, Is.False); + Assert.That(reader.Read(), Is.False); + Assert.That(reader.HasRows, Is.False); + Assert.That(reader.Read(), Is.False); + Assert.That(reader.HasRows, Is.False); + Assert.False(await reader.NextResultAsync()); + } + + command.CommandText = "SELECT 1"; + if (prepare == PrepareOrNot.Prepared) + { + command.Prepare(); + } + using (var reader = await command.ExecuteReaderAsync()) + { + reader.Read(); + reader.Close(); + Assert.Throws(() => _ = reader.HasRows); + } + + const string insertRow = "INSERT INTO my_table (name) VALUES ('foo')"; + Fixture.SpannerMock.AddOrUpdateStatementResult(insertRow, StatementResult.CreateUpdateCount(1L)); + command.CommandText = insertRow; + if (prepare == PrepareOrNot.Prepared) + { + command.Prepare(); + } + using (var reader = await command.ExecuteReaderAsync()) + { + Assert.That(reader.HasRows, Is.False); + Assert.False(reader.Read()); + } + + const string selectRow = "SELECT * FROM my_table"; + Fixture.SpannerMock.AddOrUpdateStatementResult(selectRow, StatementResult.CreateResultSet( + [Tuple.Create(TypeCode.String, "name")], [["foo"]])); + command.CommandText = selectRow; + if (prepare == PrepareOrNot.Prepared) + { + command.Prepare(); + } + using (var reader = await command.ExecuteReaderAsync()) + { + Assert.That(reader.HasRows, Is.True); + Assert.True(reader.Read()); + Assert.That(reader.GetString(0), Is.EqualTo("foo")); + } + Assert.That(await conn.ExecuteScalarAsync("SELECT 1"), Is.EqualTo(1)); + } + + [Test] + public async Task HasRowsWithoutResultSet() + { + const string sql = "DELETE FROM my_table WHERE name = 'unknown'"; + Fixture.SpannerMock.AddOrUpdateStatementResult(sql, StatementResult.CreateUpdateCount(0L)); + + await using var conn = await OpenConnectionAsync(); + await using var command = new SpannerCommand(sql, conn); + await using var reader = await command.ExecuteReaderAsync(); + Assert.That(reader.HasRows, Is.False); + } + + [Test] + public async Task IntervalAsTimeSpan() + { + const string sql = "SELECT CAST('1 hour' AS interval) AS dauer"; + Fixture.SpannerMock.AddOrUpdateStatementResult(sql, StatementResult.CreateResultSet( + [Tuple.Create(TypeCode.Interval, "dauer")], [["PT1H"]])); + + await using var conn = await OpenConnectionAsync(); + await using var command = new SpannerCommand(sql, conn); + await using var dr = await command.ExecuteReaderAsync() as SpannerDataReader; + Assert.That(dr!.HasRows); + Assert.That(await dr.ReadAsync()); + Assert.That(dr.HasRows); + var ts = dr.GetTimeSpan(0); + Assert.That(ts, Is.EqualTo(TimeSpan.FromHours(1))); + } + + [Test] + [SuppressMessage("ReSharper", "UseAwaitUsing")] + [SuppressMessage("ReSharper", "MethodHasAsyncOverload")] + public async Task CloseConnectionInMiddleOfRow() + { + const string sql = "SELECT 1, 2"; + Fixture.SpannerMock.AddOrUpdateStatementResult(sql, StatementResult.CreateResultSet( + [Tuple.Create(TypeCode.Int64, "c"), Tuple.Create(TypeCode.Int64, "c")], [[1L, 2L]])); + + using var conn = await OpenConnectionAsync(); + using var cmd = new SpannerCommand(sql, conn); + using var reader = await cmd.ExecuteReaderAsync(); + reader.Read(); + } + + [Test] + [SuppressMessage("ReSharper", "UseAwaitUsing")] + [SuppressMessage("ReSharper", "MethodHasAsyncOverload")] + public async Task InvalidCast() + { + const string sql = "SELECT 'foo'"; + Fixture.SpannerMock.AddOrUpdateStatementResult(sql, StatementResult.CreateResultSet( + [Tuple.Create(TypeCode.String, "c")], [["foo"]])); + + using var conn = await OpenConnectionAsync(); + using (var cmd = new SpannerCommand(sql, conn)) + using (var reader = await cmd.ExecuteReaderAsync()) + { + reader.Read(); + Assert.That(() => reader.GetInt32(0), Throws.Exception.TypeOf()); + } + + using (var cmd = new SpannerCommand("SELECT 1", conn)) + using (var reader = await cmd.ExecuteReaderAsync()) + { + reader.Read(); + Assert.That(() => reader.GetDateTime(0), Throws.Exception.TypeOf()); + } + Assert.That(await conn.ExecuteScalarAsync("SELECT 1"), Is.EqualTo(1)); + } + + [Test] + public async Task NullableScalar() + { + const string sql = "SELECT @p1, @p2"; + Fixture.SpannerMock.AddOrUpdateStatementResult(sql, StatementResult.CreateResultSet( + [Tuple.Create(TypeCode.Int64, "p1"), Tuple.Create(TypeCode.Int64, "p2")], [[DBNull.Value, 8L]])); + + await using var conn = await OpenConnectionAsync(); + await using var cmd = new SpannerCommand(sql, conn); + var p1 = new SpannerParameter { ParameterName = "p1", Value = DBNull.Value, DbType = DbType.Int16 }; + var p2 = new SpannerParameter { ParameterName = "p2", Value = (short)8 }; + Assert.That(p1.DbType, Is.EqualTo(DbType.Int16)); + Assert.That(p2.DbType, Is.EqualTo(DbType.String)); // This is the ADO.NET default + cmd.Parameters.Add(p1); + cmd.Parameters.Add(p2); + await using var reader = await cmd.ExecuteReaderAsync(); + await reader.ReadAsync(); + + for (var i = 0; i < cmd.Parameters.Count; i++) + { + Assert.That(reader.GetFieldType(i), Is.EqualTo(typeof(long))); + Assert.That(reader.GetDataTypeName(i), Is.EqualTo("INT64")); + } + + Assert.That(() => reader.GetFieldValue(0), Is.EqualTo(DBNull.Value)); + Assert.That(() => reader.GetFieldValue(0), Throws.TypeOf()); + Assert.That(() => reader.GetFieldValue(0), Throws.Nothing); + Assert.That(reader.GetFieldValue(0), Is.Null); + + Assert.That(() => reader.GetFieldValue(1), Throws.Nothing); + Assert.That(() => reader.GetFieldValue(1), Throws.Nothing); + Assert.That(() => reader.GetFieldValue(1), Throws.Nothing); + Assert.That(reader.GetFieldValue(1), Is.EqualTo(8)); + Assert.That(reader.GetFieldValue(1), Is.EqualTo(8)); + Assert.That(reader.GetFieldValue(1), Is.EqualTo(8)); + } + + [Test] + public async Task ReaderCloseAndDispose() + { + await using var conn = await OpenConnectionAsync(); + await using var cmd1 = conn.CreateCommand(); + cmd1.CommandText = "SELECT 1"; + + var reader1 = await cmd1.ExecuteReaderAsync(CommandBehavior.CloseConnection); + await reader1.CloseAsync(); + + await conn.OpenAsync(); + cmd1.Connection = conn; + var reader2 = await cmd1.ExecuteReaderAsync(CommandBehavior.CloseConnection); + Assert.That(reader1, Is.Not.SameAs(reader2)); + Assert.Throws(() => _ = reader2.GetInt64(0)); + + await reader1.DisposeAsync(); + + Assert.Throws(() => _ = reader2.GetInt64(0)); + } + + [Test] + public async Task ConnectionCloseAndReaderDispose() + { + await using var conn = await OpenConnectionAsync(); + await using var cmd1 = conn.CreateCommand(); + cmd1.CommandText = "SELECT 1"; + + var reader1 = await cmd1.ExecuteReaderAsync(); + await conn.CloseAsync(); + await conn.OpenAsync(); + + var reader2 = await cmd1.ExecuteReaderAsync(); + Assert.That(reader1, Is.Not.SameAs(reader2)); + Assert.Throws(() => _ = reader2.GetInt64(0)); + + await reader1.DisposeAsync(); + + Assert.Throws(() => _ = reader2.GetInt64(0)); + } + + [Test] + public async Task UnboundReaderReuse() + { + Fixture.SpannerMock.AddOrUpdateStatementResult("SELECT 2", StatementResult.CreateResultSet( + [Tuple.Create(TypeCode.Int64, "c")], [[2L]])); + Fixture.SpannerMock.AddOrUpdateStatementResult("SELECT 3", StatementResult.CreateResultSet( + [Tuple.Create(TypeCode.Int64, "c")], [[3L]])); + + await using var dataSource = CreateDataSource(csb => { }); + await using var conn1 = await dataSource.OpenConnectionAsync(); + await using var cmd1 = conn1.CreateCommand(); + cmd1.CommandText = "SELECT 1"; + var reader1 = await cmd1.ExecuteReaderAsync(); + await using (var __ = reader1) + { + Assert.That(async () => await reader1.ReadAsync(), Is.EqualTo(true)); + Assert.That(() => reader1.GetInt32(0), Is.EqualTo(1)); + + await reader1.CloseAsync(); + await conn1.CloseAsync(); + } + + await using var conn2 = await dataSource.OpenConnectionAsync(); + await using var cmd2 = conn2.CreateCommand(); + cmd2.CommandText = "SELECT 2"; + var reader2 = await cmd2.ExecuteReaderAsync(); + await using (var __ = reader2) + { + Assert.That(async () => await reader2.ReadAsync(), Is.EqualTo(true)); + Assert.That(() => reader2.GetInt32(0), Is.EqualTo(2)); + Assert.That(reader1, Is.Not.SameAs(reader2)); + + await reader2.CloseAsync(); + await conn2.CloseAsync(); + } + + await using var conn3 = await dataSource.OpenConnectionAsync(); + await using var cmd3 = conn3.CreateCommand(); + cmd3.CommandText = "SELECT 3"; + var reader3 = await cmd3.ExecuteReaderAsync(); + await using (var __ = reader3) + { + Assert.That(async () => await reader3.ReadAsync(), Is.EqualTo(true)); + Assert.That(() => reader3.GetInt32(0), Is.EqualTo(3)); + Assert.That(reader1, Is.Not.SameAs(reader3)); + + await reader3.CloseAsync(); + await conn3.CloseAsync(); + } + } + + [Test] + public async Task ReadStringAsChar() + { + const string sql = "SELECT 'abcdefgh', 'ijklmnop'"; + Fixture.SpannerMock.AddOrUpdateStatementResult(sql, StatementResult.CreateResultSet( + [Tuple.Create(TypeCode.String, "c1"), Tuple.Create(TypeCode.String, "c2")], [["abcdefgh", "ijklmnop"]])); + + await using var conn = await OpenConnectionAsync(); + await using var cmd = conn.CreateCommand(); + cmd.CommandText = sql; + + await using var reader = await cmd.ExecuteReaderAsync(); + Assert.That(await reader.ReadAsync()); + Assert.That(reader.GetChar(0), Is.EqualTo('a')); + Assert.That(reader.GetChar(0), Is.EqualTo('a')); + Assert.That(reader.GetChar(1), Is.EqualTo('i')); + } + + [Test] + public async Task GetBytes() + { + byte[] expected = [1, 2, 3, 4, 5]; + var base64 = Convert.ToBase64String(expected); + const string query = "SELECT bytes, 'foo', bytes, 'bar', bytes, bytes FROM my_table"; + Fixture.SpannerMock.AddOrUpdateStatementResult(query, StatementResult.CreateResultSet( + [ + Tuple.Create(TypeCode.Bytes, "bytes"), + Tuple.Create(TypeCode.String, "foo"), + Tuple.Create(TypeCode.Bytes, "bytes"), + Tuple.Create(TypeCode.String, "bar"), + Tuple.Create(TypeCode.Bytes, "bytes"), + Tuple.Create(TypeCode.Bytes, "bytes"), + ], [[ + base64, + "foo", + base64, + "bar", + base64, + base64, + ]])); + + await using var conn = await OpenConnectionAsync(); + var actual = new byte[expected.Length]; + + await using var cmd = new SpannerCommand(query, conn); + await using var reader = await cmd.ExecuteReaderAsync(); + await reader.ReadAsync(); + + Assert.That(reader.GetBytes(0, 0, actual, 0, 2), Is.EqualTo(2)); + Assert.That(actual[0], Is.EqualTo(expected[0])); + Assert.That(actual[1], Is.EqualTo(expected[1])); + Assert.That(reader.GetBytes(0, 0, null, 0, 0), Is.EqualTo(expected.Length), "Bad column length"); + + Assert.That(reader.GetBytes(0, 0, actual, 4, 1), Is.EqualTo(1)); + Assert.That(actual[4], Is.EqualTo(expected[0])); + + Assert.That(reader.GetBytes(0, 2, actual, 2, 3), Is.EqualTo(3)); + Assert.That(actual, Is.EqualTo(expected)); + Assert.That(reader.GetBytes(0, 0, null, 0, 0), Is.EqualTo(expected.Length), "Bad column length"); + + Assert.That(reader.GetString(1), Is.EqualTo("foo")); + reader.GetBytes(2, 0, actual, 0, 2); + + // Jump to another column from the middle of the column + reader.GetBytes(4, 0, actual, 0, 2); + Assert.That(reader.GetBytes(4, expected.Length - 1, actual, 0, 2), Is.EqualTo(1), + "Length greater than data length"); + Assert.That(actual[0], Is.EqualTo(expected[^1]), "Length greater than data length"); + Assert.That(() => reader.GetBytes(4, 0, actual, 0, actual.Length + 1), + Throws.Exception.TypeOf(), "Length great than output buffer length"); + // Close in the middle of a column + reader.GetBytes(5, 0, actual, 0, 2); + + var result = (byte[]) cmd.ExecuteScalar()!; + Assert.That(result.Length, Is.EqualTo(5)); + } + + [Test] + public async Task GetStreamSecondTimeWorks() + { + var expected = new byte[] { 1, 2, 3, 4, 5, 6, 7, 8 }; + var base64 = Convert.ToBase64String(expected); + var sql = $"SELECT from_base64({base64})"; + Fixture.SpannerMock.AddOrUpdateStatementResult(sql, StatementResult.CreateResultSet( + [Tuple.Create(TypeCode.Bytes, "from_base64")], [[base64]])); + + await using var conn = await OpenConnectionAsync(); + await using var cmd = new SpannerCommand(sql, conn); + await using var reader = await cmd.ExecuteReaderAsync(); + + await reader.ReadAsync(); + Assert.That(reader.GetStream(0), Is.Not.Null); + Assert.That(reader.GetStream(0), Is.Not.Null); + } + + public static IEnumerable GetStreamCases() + { + var binary = MemoryMarshal + .AsBytes(Enumerable.Range(0, 1024).ToArray()) + .ToArray(); + yield return (binary, binary); + + var bigBinary = MemoryMarshal + .AsBytes(Enumerable.Range(0, 8193).ToArray()) + .ToArray(); + yield return (bigBinary, bigBinary); + + var bigint = 0xDEADBEEFL; + var bigintBinary = BitConverter.GetBytes( + BitConverter.IsLittleEndian + ? BinaryPrimitives.ReverseEndianness(bigint) + : bigint); + yield return (bigint, bigintBinary); + } + + [Test] + [SuppressMessage("ReSharper", "MethodHasAsyncOverload")] + [SuppressMessage("ReSharper", "UseAwaitUsing")] + public async Task GetStream( + [Values] bool isAsync, + [ValueSource(nameof(GetStreamCases))] (T Generic, byte[] Binary) value) + { + const string sql = "SELECT @p, @p"; + var base64 = Convert.ToBase64String(value.Binary); + Fixture.SpannerMock.AddOrUpdateStatementResult(sql, StatementResult.CreateResultSet( + [Tuple.Create(TypeCode.Bytes, "p"), Tuple.Create(TypeCode.Bytes, "p")], [[base64, base64]])); + + var expected = value.Binary; + var actual = new byte[expected.Length]; + + using var conn = await OpenConnectionAsync(); + using var cmd = new SpannerCommand(sql, conn); + cmd.Parameters.Add(new SpannerParameter("p", value.Generic)); + using var reader = await cmd.ExecuteReaderAsync(); + + await reader.ReadAsync(); + + using var stream = reader.GetStream(0); + Assert.That(stream.Length, Is.EqualTo(expected.Length)); + + var position = 0; + while (position < actual.Length) + { + if (isAsync) + { + position += await stream.ReadAsync(actual, position, actual.Length - position); + } + else + { + position += stream.Read(actual, position, actual.Length - position); + } + } + Assert.That(actual, Is.EqualTo(expected)); + } + + [Test] + [SuppressMessage("ReSharper", "MethodHasAsyncOverload")] + [SuppressMessage("ReSharper", "UseAwaitUsing")] + public async Task OpenStreamWhenChangingColumns([Values(true, false)] bool isAsync) + { + var data = new byte[] { 1, 2, 3 }; + var base64 = Convert.ToBase64String(data); + const string sql = "SELECT @p, @p"; + Fixture.SpannerMock.AddOrUpdateStatementResult(sql, StatementResult.CreateResultSet( + [Tuple.Create(TypeCode.Bytes, "p"), Tuple.Create(TypeCode.Bytes, "p")], [[base64, base64]])); + + using var conn = await OpenConnectionAsync(); + using var cmd = new SpannerCommand(sql, conn); + cmd.Parameters.Add(new SpannerParameter("p", data)); + using var reader = await cmd.ExecuteReaderAsync(); + reader.Read(); + var stream = reader.GetStream(0); + _ = reader.GetValue(1); + Assert.That(() => stream.ReadByte(), Throws.Nothing); + } + + [Test] + [SuppressMessage("ReSharper", "MethodHasAsyncOverload")] + [SuppressMessage("ReSharper", "UseAwaitUsing")] + public async Task OpenStreamWhenChangingRows([Values(true, false)] bool isAsync) + { + var data = new byte[] { 1, 2, 3 }; + var base64 = Convert.ToBase64String(data); + const string sql = "SELECT @p"; + Fixture.SpannerMock.AddOrUpdateStatementResult(sql, StatementResult.CreateResultSet( + [Tuple.Create(TypeCode.Bytes, "p")], [[base64]])); + + using var conn = await OpenConnectionAsync(); + using var cmd = new SpannerCommand(sql, conn); + cmd.Parameters.Add(new SpannerParameter("p", data)); + using var reader = await cmd.ExecuteReaderAsync(); + reader.Read(); + var s1 = reader.GetStream(0); + reader.Read(); + Assert.That(() => s1.ReadByte(), Throws.Nothing); + } + + [Test] + [SuppressMessage("ReSharper", "MethodHasAsyncOverload")] + [SuppressMessage("ReSharper", "UseAwaitUsing")] + public async Task GetBytesWithNull([Values(true, false)] bool isAsync) + { + const string sql = "SELECT bytes FROM my_table"; + Fixture.SpannerMock.AddOrUpdateStatementResult(sql, StatementResult.CreateResultSet( + [Tuple.Create(TypeCode.Bytes, "p")], [[DBNull.Value]])); + + using var conn = await OpenConnectionAsync(); + var buf = new byte[8]; + using var cmd = new SpannerCommand(sql, conn); + using var reader = await cmd.ExecuteReaderAsync(); + reader.Read(); + Assert.That(reader.IsDBNull(0), Is.True); + Assert.That(() => reader.GetBytes(0, 0, buf, 0, 1), Throws.Exception.TypeOf(), "GetBytes"); + Assert.That(() => reader.GetStream(0), Throws.Exception.TypeOf(), "GetStream"); + Assert.That(() => reader.GetBytes(0, 0, null, 0, 0), Throws.Exception.TypeOf(), "GetBytes with null buffer"); + } + + [Test] + public async Task GetStreamSeek() + { + const string sql = "SELECT 'abcdefgh'"; + Fixture.SpannerMock.AddOrUpdateStatementResult(sql, StatementResult.CreateResultSet( + [Tuple.Create(TypeCode.String, "c")], [["abcdefgh"]])); + + await using var conn = await OpenConnectionAsync(); + await using var cmd = conn.CreateCommand(); + cmd.CommandText = sql; + await using var reader = await cmd.ExecuteReaderAsync(); + await reader.ReadAsync(); + + var buffer = new byte[4]; + + await using var stream = reader.GetStream(0); + Assert.That(stream.CanSeek); + + var seekPosition = stream.Seek(-1, SeekOrigin.End); + Assert.That(seekPosition, Is.EqualTo(stream.Length - 1)); + var read = stream.Read(buffer); + Assert.That(read, Is.EqualTo(1)); + Assert.That(Encoding.ASCII.GetString(buffer, 0, 1), Is.EqualTo("h")); + read = stream.Read(buffer); + Assert.That(read, Is.EqualTo(0)); + + seekPosition = stream.Seek(2, SeekOrigin.Begin); + Assert.That(seekPosition, Is.EqualTo(2)); + read = stream.Read(buffer); + Assert.That(read, Is.EqualTo(buffer.Length)); + Assert.That(Encoding.ASCII.GetString(buffer), Is.EqualTo("cdef")); + + seekPosition = stream.Seek(-3, SeekOrigin.Current); + Assert.That(seekPosition, Is.EqualTo(3)); + read = stream.Read(buffer); + Assert.That(read, Is.EqualTo(buffer.Length)); + Assert.That(Encoding.ASCII.GetString(buffer), Is.EqualTo("defg")); + + stream.Position = 1; + read = stream.Read(buffer); + Assert.That(read, Is.EqualTo(buffer.Length)); + Assert.That(Encoding.ASCII.GetString(buffer), Is.EqualTo("bcde")); + } + + [Test] + [SuppressMessage("ReSharper", "MethodHasAsyncOverload")] + [SuppressMessage("ReSharper", "UseAwaitUsing")] + public async Task GetChars() + { + const string str = "ABCDE"; + var expected = str.ToCharArray(); + var actual = new char[expected.Length]; + var queryText = $"SELECT '{str}', 3, '{str}', 4, '{str}', '{str}', '{str}'"; + Fixture.SpannerMock.AddOrUpdateStatementResult(queryText, StatementResult.CreateResultSet( + [ + Tuple.Create(TypeCode.String, "str"), + Tuple.Create(TypeCode.Int64, "c"), + Tuple.Create(TypeCode.String, "str"), + Tuple.Create(TypeCode.Int64, "c"), + Tuple.Create(TypeCode.String, "str"), + Tuple.Create(TypeCode.String, "str"), + Tuple.Create(TypeCode.String, "str"), + ], [[ str, 3L, str, 4L, str, str, str, ]])); + + using var conn = await OpenConnectionAsync(); + using var cmd = new SpannerCommand(queryText, conn); + using var reader = await cmd.ExecuteReaderAsync(); + reader.Read(); + + Assert.That(reader.GetChars(0, 0, actual, 0, 2), Is.EqualTo(2)); + Assert.That(actual[0], Is.EqualTo(expected[0])); + Assert.That(actual[1], Is.EqualTo(expected[1])); + Assert.That(reader.GetChars(0, 0, null, 0, 0), Is.EqualTo(expected.Length), "Bad column length"); + + Assert.That(reader.GetChars(2, 0, actual, 0, 2), Is.EqualTo(2)); + Assert.That(reader.GetChars(2, 0, actual, 4, 1), Is.EqualTo(1)); + Assert.That(actual[4], Is.EqualTo(expected[0])); + Assert.That(reader.GetChars(2, 2, actual, 2, 3), Is.EqualTo(3)); + Assert.That(actual, Is.EqualTo(expected)); + + //Assert.That(reader.GetChars(2, 0, null, 0, 0), Is.EqualTo(expected.Length), "Bad column length"); + + Assert.That(() => reader.GetChars(3, 0, null, 0, 0), Throws.Exception.TypeOf(), "GetChars on non-text"); + Assert.That(() => reader.GetChars(3, 0, actual, 0, 1), Throws.Exception.TypeOf(), "GetChars on non-text"); + Assert.That(reader.GetInt32(3), Is.EqualTo(4)); + reader.GetChars(4, 0, actual, 0, 2); + // Jump to another column from the middle of the column + reader.GetChars(5, 0, actual, 0, 2); + Assert.That(reader.GetChars(5, expected.Length - 1, actual, 0, 2), Is.EqualTo(1), "Length greater than data length"); + Assert.That(actual[0], Is.EqualTo(expected[^1]), "Length greater than data length"); + Assert.That(() => reader.GetChars(5, 0, actual, 0, actual.Length + 1), Throws.Exception.TypeOf(), "Length great than output buffer length"); + // Close in the middle of a column + reader.GetChars(6, 0, actual, 0, 2); + } + + [Test] + public async Task GetCharsAdvanceConsumed() + { + const string value = "01234567"; + var sql = $"SELECT '{value}'"; + Fixture.SpannerMock.AddOrUpdateStatementResult(sql, StatementResult.CreateResultSet( + [Tuple.Create(TypeCode.String, "c")], [[value]])); + + await using var conn = await OpenConnectionAsync(); + await using var cmd = new SpannerCommand(sql, conn); + await using var reader = await cmd.ExecuteReaderAsync(); + await reader.ReadAsync(); + + var buffer = new char[2]; + // Don't start at the beginning of the column. + reader.GetChars(0, 2, buffer, 0, 2); + Assert.That(buffer, Is.EqualTo(new []{'2', '3'})); + reader.GetChars(0, 4, buffer, 0, 2); + Assert.That(buffer, Is.EqualTo(new []{'4', '5'})); + reader.GetChars(0, 6, buffer, 0, 2); + Assert.That(buffer, Is.EqualTo(new []{'6', '7'})); + reader.GetChars(0, 7, buffer, 0, 2); + Assert.That(buffer, Is.EqualTo(new []{'7', '7'})); + + reader.GetChars(0, 4, buffer, 0, 2); + Assert.That(buffer, Is.EqualTo(new []{'4', '5'})); + reader.GetChars(0, 6, buffer, 0, 2); + Assert.That(buffer, Is.EqualTo(new []{'6', '7'})); + } + + [Test] + [SuppressMessage("ReSharper", "MethodHasAsyncOverload")] + public async Task GetTextReader([Values(true, false)] bool isAsync) + { + const string str = "ABCDE"; + var queryText = $@"SELECT '{str}', 'foo'"; + Fixture.SpannerMock.AddOrUpdateStatementResult(queryText, StatementResult.CreateResultSet( + [Tuple.Create(TypeCode.String, "c"), Tuple.Create(TypeCode.String, "c")], [[str, "foo"]])); + + await using var conn = await OpenConnectionAsync(); + var expected = str.ToCharArray(); + var actual = new char[expected.Length]; + + await using var cmd = new SpannerCommand(queryText, conn); + await using var reader = await cmd.ExecuteReaderAsync(); + await reader.ReadAsync(); + + var textReader = reader.GetTextReader(0); + textReader.Read(actual, 0, 2); + Assert.That(actual[0], Is.EqualTo(expected[0])); + Assert.That(actual[1], Is.EqualTo(expected[1])); + Assert.That(() => reader.GetTextReader(0), Throws.Nothing, "Sequential text reader twice on same column"); + textReader.Read(actual, 2, 1); + Assert.That(actual[2], Is.EqualTo(expected[2])); + textReader.Dispose(); + + Assert.That(reader.GetChars(0, 0, actual, 4, 1), Is.EqualTo(1)); + Assert.That(actual[4], Is.EqualTo(expected[0])); + Assert.That(reader.GetString(1), Is.EqualTo("foo")); + } + + [Test] + [SuppressMessage("ReSharper", "MethodHasAsyncOverload")] + public async Task TextReaderZeroLengthColumn() + { + const string sql = "SELECT ''"; + Fixture.SpannerMock.AddOrUpdateStatementResult(sql, StatementResult.CreateResultSet( + [Tuple.Create(TypeCode.String, "c")], [[""]])); + + await using var conn = await OpenConnectionAsync(); + await using var cmd = conn.CreateCommand(); + cmd.CommandText = sql; + + await using var reader = await cmd.ExecuteReaderAsync(); + Assert.That(await reader.ReadAsync()); + + using var textReader = reader.GetTextReader(0); + Assert.That(textReader.Peek(), Is.EqualTo(-1)); + Assert.That(textReader.ReadToEnd(), Is.EqualTo(string.Empty)); + } + + [Test] + [SuppressMessage("ReSharper", "MethodHasAsyncOverload")] + [SuppressMessage("ReSharper", "UseAwaitUsing")] + public async Task OpenTextReaderWhenChangingColumns() + { + const string sql = "SELECT 'some_text', 'some_text'"; + Fixture.SpannerMock.AddOrUpdateStatementResult(sql, StatementResult.CreateResultSet( + [Tuple.Create(TypeCode.String, "c"), Tuple.Create(TypeCode.String, "c")], [["some_text", "some-text"]])); + + using var conn = await OpenConnectionAsync(); + using var cmd = new SpannerCommand(sql, conn); + using var reader = await cmd.ExecuteReaderAsync(); + reader.Read(); + var textReader = reader.GetTextReader(0); + _ = reader.GetValue(1); + Assert.That(() => textReader.Peek(), Throws.Nothing); + } + + [Test] + [SuppressMessage("ReSharper", "MethodHasAsyncOverload")] + [SuppressMessage("ReSharper", "UseAwaitUsing")] + public async Task OpenTextReaderWhenChangingRows() + { + const string sql = "SELECT 'some_text', 'some_text'"; + Fixture.SpannerMock.AddOrUpdateStatementResult(sql, StatementResult.CreateResultSet( + [Tuple.Create(TypeCode.String, "c"), Tuple.Create(TypeCode.String, "c")], [["some_text", "some-text"]])); + + using var conn = await OpenConnectionAsync(); + using var cmd = new SpannerCommand(sql, conn); + using var reader = await cmd.ExecuteReaderAsync(); + reader.Read(); + var tr1 = reader.GetTextReader(0); + reader.Read(); + Assert.That(() => tr1.Peek(), Throws.Nothing); + } + + [Test] + public async Task GetCharsWhenNull() + { + const string sql = "SELECT cast(null as string)"; + Fixture.SpannerMock.AddOrUpdateStatementResult(sql, StatementResult.CreateResultSet( + [Tuple.Create(TypeCode.String, "c")], [[DBNull.Value]])); + + var buf = new char[8]; + await using var conn = await OpenConnectionAsync(); + await using var cmd = new SpannerCommand(sql, conn); + await using var reader = await cmd.ExecuteReaderAsync(); + await reader.ReadAsync(); + Assert.That(reader.IsDBNull(0), Is.True); + Assert.That(reader.GetChars(0, 0, buf, 0, 1), Is.EqualTo(0)); + Assert.That(() => reader.GetTextReader(0), Throws.Nothing, "GetTextReader"); + Assert.That(reader.GetChars(0, 0, null, 0, 0), Is.EqualTo(0), "GetChars with null buffer"); + } + + [Test] + public async Task GetTextReaderAfterConsumingColumnWorks() + { + const string sql = "SELECT 'foo'"; + Fixture.SpannerMock.AddOrUpdateStatementResult(sql, StatementResult.CreateResultSet( + [Tuple.Create(TypeCode.String, "c")], [["foo"]])); + + await using var conn = await OpenConnectionAsync(); + await using var cmd = new SpannerCommand(sql, conn); + await using var reader = await cmd.ExecuteReaderAsync(CommandBehavior.SequentialAccess); + await reader.ReadAsync(); + + _ = reader.GetString(0); + Assert.That(() => reader.GetTextReader(0), Throws.Nothing); + } + + [Test] + public async Task GetTextReaderInMiddleOfColumnWorks() + { + const string sql = "SELECT 'foo'"; + Fixture.SpannerMock.AddOrUpdateStatementResult(sql, StatementResult.CreateResultSet( + [Tuple.Create(TypeCode.String, "c")], [["foo"]])); + + await using var conn = await OpenConnectionAsync(); + await using var cmd = new SpannerCommand(sql, conn); + await using var reader = await cmd.ExecuteReaderAsync(CommandBehavior.SequentialAccess); + await reader.ReadAsync(); + + _ = reader.GetChars(0, 0, new char[2], 0, 2); + Assert.That(() => reader.GetTextReader(0), Throws.Nothing); + } + +} \ No newline at end of file diff --git a/drivers/spanner-ado-net/spanner-ado-net-tests/SchemaTests.cs b/drivers/spanner-ado-net/spanner-ado-net-tests/SchemaTests.cs new file mode 100644 index 00000000..8f3b0661 --- /dev/null +++ b/drivers/spanner-ado-net/spanner-ado-net-tests/SchemaTests.cs @@ -0,0 +1,251 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +using System.Data; +using System.Data.Common; +using System.Text.RegularExpressions; +using NUnit.Framework.Internal; +using TypeCode = Google.Cloud.Spanner.V1.TypeCode; + +namespace Google.Cloud.Spanner.DataProvider.Tests; + +public class SchemaTests : AbstractMockServerTests +{ + [Test] + public async Task MetaDataCollections() + { + await using var conn = await OpenConnectionAsync(); + + var metaDataCollections = await conn.GetSchemaAsync(DbMetaDataCollectionNames.MetaDataCollections); + Assert.That(metaDataCollections.Rows, Has.Count.GreaterThan(0)); + + foreach (var row in metaDataCollections.Rows.OfType()) + { + var collectionName = (string)row!["CollectionName"]; + Assert.That(await conn.GetSchemaAsync(collectionName), Is.Not.Null, $"Collection {collectionName} advertise in MetaDataCollections but is null"); + } + } + + [Test] + public async Task NoParameter() + { + await using var conn = await OpenConnectionAsync(); + + var dataTable1 = conn.GetSchema(); + var collections1 = dataTable1.Rows + .Cast() + .Select(r => (string)r["CollectionName"]) + .ToList(); + + var dataTable2 = conn.GetSchema(DbMetaDataCollectionNames.MetaDataCollections); + var collections2 = dataTable2.Rows + .Cast() + .Select(r => (string)r["CollectionName"]) + .ToList(); + + Assert.That(collections1, Is.EquivalentTo(collections2)); + } + + [Test] + public async Task CaseInsensitiveCollectionName() + { + await using var conn = await OpenConnectionAsync(); + + var dataTable1 = conn.GetSchema(DbMetaDataCollectionNames.MetaDataCollections); + var collections1 = dataTable1.Rows + .Cast() + .Select(r => (string)r["CollectionName"]) + .ToList(); + + var dataTable2 = conn.GetSchema("METADATACOLLECTIONS"); + var collections2 = dataTable2.Rows + .Cast() + .Select(r => (string)r["CollectionName"]) + .ToList(); + + var dataTable3 = conn.GetSchema("metadatacollections"); + var collections3 = dataTable3.Rows + .Cast() + .Select(r => (string)r["CollectionName"]) + .ToList(); + + var dataTable4 = conn.GetSchema("MetaDataCollections"); + var collections4 = dataTable4.Rows + .Cast() + .Select(r => (string)r["CollectionName"]) + .ToList(); + + var dataTable5 = conn.GetSchema("METADATACOLLECTIONS", null!); + var collections5 = dataTable5.Rows + .Cast() + .Select(r => (string)r["CollectionName"]) + .ToList(); + + var dataTable6 = conn.GetSchema("metadatacollections", null!); + var collections6 = dataTable6.Rows + .Cast() + .Select(r => (string)r["CollectionName"]) + .ToList(); + + var dataTable7 = conn.GetSchema("MetaDataCollections", null!); + var collections7 = dataTable7.Rows + .Cast() + .Select(r => (string)r["CollectionName"]) + .ToList(); + + Assert.That(collections1, Is.EquivalentTo(collections2)); + Assert.That(collections1, Is.EquivalentTo(collections3)); + Assert.That(collections1, Is.EquivalentTo(collections4)); + Assert.That(collections1, Is.EquivalentTo(collections5)); + Assert.That(collections1, Is.EquivalentTo(collections6)); + Assert.That(collections1, Is.EquivalentTo(collections7)); + } + + [Test] + public async Task DataSourceInformation() + { + await using var conn = await OpenConnectionAsync(); + var dataTable = conn.GetSchema(DbMetaDataCollectionNames.MetaDataCollections); + var metadata = dataTable.Rows + .Cast() + .Single(r => r["CollectionName"].Equals("DataSourceInformation")); + Assert.That(metadata["NumberOfRestrictions"], Is.Zero); + Assert.That(metadata["NumberOfIdentifierParts"], Is.Zero); + + var dataSourceInfo = conn.GetSchema(DbMetaDataCollectionNames.DataSourceInformation); + var row = dataSourceInfo.Rows.Cast().Single(); + + Assert.That(row["DataSourceProductName"], Is.EqualTo("Spanner")); + Assert.That(row["DataSourceProductVersion"], Is.EqualTo("1.0.0")); + Assert.That(row["DataSourceProductVersionNormalized"], Is.EqualTo("001.000.0000")); + + Assert.That(Regex.Match("`some_identifier`", (string)row["QuotedIdentifierPattern"]).Groups[1].Value, + Is.EqualTo("some_identifier")); + } + + [Test] + public async Task DataTypes() + { + await using var connection = await OpenConnectionAsync(); + + var dataTable = connection.GetSchema(DbMetaDataCollectionNames.MetaDataCollections); + var metadata = dataTable.Rows + .Cast() + .Single(r => r["CollectionName"].Equals("DataTypes")); + Assert.That(metadata["NumberOfRestrictions"], Is.Zero); + Assert.That(metadata["NumberOfIdentifierParts"], Is.Zero); + + var dataTypes = connection.GetSchema(DbMetaDataCollectionNames.DataTypes); + + var boolRow = dataTypes.Rows.Cast().Single(r => r["TypeName"].Equals("Bool")); + Assert.That(boolRow["DataType"], Is.EqualTo("System.Boolean")); + Assert.That(boolRow["ProviderDbType"], Is.EqualTo((int)TypeCode.Bool)); + Assert.That(boolRow["IsUnsigned"], Is.EqualTo(DBNull.Value)); + + var bytesRow = dataTypes.Rows.Cast().Single(r => r["TypeName"].Equals("Bytes")); + Assert.That(bytesRow["DataType"], Is.EqualTo("System.Byte[]")); + Assert.That(bytesRow["ProviderDbType"], Is.EqualTo((int)TypeCode.Bytes)); + Assert.That(bytesRow["IsUnsigned"], Is.EqualTo(DBNull.Value)); + Assert.That(bytesRow["IsBestMatch"], Is.True); + + var dateRow = dataTypes.Rows.Cast().Single(r => r["TypeName"].Equals("Date")); + Assert.That(dateRow["DataType"], Is.EqualTo("System.DateOnly")); + Assert.That(dateRow["ProviderDbType"], Is.EqualTo((int)TypeCode.Date)); + Assert.That(dateRow["IsUnsigned"], Is.EqualTo(DBNull.Value)); + Assert.That(dateRow["IsBestMatch"], Is.True); + + var enumRow = dataTypes.Rows.Cast().Single(r => r["TypeName"].Equals("Enum")); + Assert.That(enumRow["DataType"], Is.EqualTo("System.Int64")); + Assert.That(enumRow["ProviderDbType"], Is.EqualTo((int)TypeCode.Enum)); + Assert.That(enumRow["IsUnsigned"], Is.EqualTo(DBNull.Value)); + Assert.That(enumRow["IsBestMatch"], Is.False); + + var float32Row = dataTypes.Rows.Cast().Single(r => r["TypeName"].Equals("Float32")); + Assert.That(float32Row["DataType"], Is.EqualTo("System.Single")); + Assert.That(float32Row["ProviderDbType"], Is.EqualTo((int)TypeCode.Float32)); + Assert.That(float32Row["IsUnsigned"], Is.False); + Assert.That(float32Row["IsBestMatch"], Is.True); + + var float64Row = dataTypes.Rows.Cast().Single(r => r["TypeName"].Equals("Float64")); + Assert.That(float64Row["DataType"], Is.EqualTo("System.Double")); + Assert.That(float64Row["ProviderDbType"], Is.EqualTo((int)TypeCode.Float64)); + Assert.That(float64Row["IsUnsigned"], Is.False); + Assert.That(float64Row["IsBestMatch"], Is.True); + + var int64Row = dataTypes.Rows.Cast().Single(r => r["TypeName"].Equals("Int64")); + Assert.That(int64Row["DataType"], Is.EqualTo("System.Int64")); + Assert.That(int64Row["ProviderDbType"], Is.EqualTo((int)TypeCode.Int64)); + Assert.That(int64Row["IsUnsigned"], Is.False); + Assert.That(int64Row["IsBestMatch"], Is.True); + + var intervalRow = dataTypes.Rows.Cast().Single(r => r["TypeName"].Equals("Interval")); + Assert.That(intervalRow["DataType"], Is.EqualTo("System.TimeSpan")); + Assert.That(intervalRow["ProviderDbType"], Is.EqualTo((int)TypeCode.Interval)); + Assert.That(intervalRow["IsUnsigned"], Is.EqualTo(DBNull.Value)); + Assert.That(intervalRow["IsBestMatch"], Is.True); + + var jsonRow = dataTypes.Rows.Cast().Single(r => r["TypeName"].Equals("Json")); + Assert.That(jsonRow["DataType"], Is.EqualTo("System.String")); + Assert.That(jsonRow["ProviderDbType"], Is.EqualTo((int)TypeCode.Json)); + Assert.That(jsonRow["IsUnsigned"], Is.EqualTo(DBNull.Value)); + Assert.That(jsonRow["IsBestMatch"], Is.False); + + var numericRow = dataTypes.Rows.Cast().Single(r => r["TypeName"].Equals("Numeric")); + Assert.That(numericRow["DataType"], Is.EqualTo("System.Decimal")); + Assert.That(numericRow["ProviderDbType"], Is.EqualTo((int)TypeCode.Numeric)); + Assert.That(numericRow["IsUnsigned"], Is.False); + Assert.That(numericRow["IsBestMatch"], Is.True); + + var protoRow = dataTypes.Rows.Cast().Single(r => r["TypeName"].Equals("Proto")); + Assert.That(protoRow["DataType"], Is.EqualTo("System.Byte[]")); + Assert.That(protoRow["ProviderDbType"], Is.EqualTo((int)TypeCode.Proto)); + Assert.That(protoRow["IsUnsigned"], Is.EqualTo(DBNull.Value)); + Assert.That(protoRow["IsBestMatch"], Is.False); + + var stringRow = dataTypes.Rows.Cast().Single(r => r["TypeName"].Equals("String")); + Assert.That(stringRow["DataType"], Is.EqualTo("System.String")); + Assert.That(stringRow["ProviderDbType"], Is.EqualTo((int)TypeCode.String)); + Assert.That(stringRow["IsUnsigned"], Is.EqualTo(DBNull.Value)); + Assert.That(stringRow["IsBestMatch"], Is.True); + + var timestampRow = dataTypes.Rows.Cast().Single(r => r["TypeName"].Equals("Timestamp")); + Assert.That(timestampRow["DataType"], Is.EqualTo("System.DateTime")); + Assert.That(timestampRow["ProviderDbType"], Is.EqualTo((int)TypeCode.Timestamp)); + Assert.That(timestampRow["IsUnsigned"], Is.EqualTo(DBNull.Value)); + Assert.That(timestampRow["IsBestMatch"], Is.True); + + var uuidRow = dataTypes.Rows.Cast().Single(r => r["TypeName"].Equals("Uuid")); + Assert.That(uuidRow["DataType"], Is.EqualTo("System.Guid")); + Assert.That(uuidRow["ProviderDbType"], Is.EqualTo((int)TypeCode.Uuid)); + Assert.That(uuidRow["IsUnsigned"], Is.EqualTo(DBNull.Value)); + Assert.That(uuidRow["IsBestMatch"], Is.True); + } + + [Test] + public async Task Restrictions() + { + await using var conn = await OpenConnectionAsync(); + var restrictions = conn.GetSchema(DbMetaDataCollectionNames.Restrictions); + Assert.That(restrictions.Rows, Has.Count.GreaterThan(0)); + } + + [Test] + public async Task ReservedWords() + { + await using var conn = await OpenConnectionAsync(); + var reservedWords = conn.GetSchema(DbMetaDataCollectionNames.ReservedWords); + Assert.That(reservedWords.Rows, Has.Count.GreaterThan(0)); + } + +} \ No newline at end of file diff --git a/drivers/spanner-ado-net/spanner-ado-net-tests/SpannerParameterCollectionTests.cs b/drivers/spanner-ado-net/spanner-ado-net-tests/SpannerParameterCollectionTests.cs new file mode 100644 index 00000000..0ce010a0 --- /dev/null +++ b/drivers/spanner-ado-net/spanner-ado-net-tests/SpannerParameterCollectionTests.cs @@ -0,0 +1,317 @@ +using System.Data; +using System.Data.Common; +using System.Diagnostics.CodeAnalysis; + +namespace Google.Cloud.Spanner.DataProvider.Tests; + +public class SpannerParameterCollectionTests : AbstractMockServerTests +{ + [Test] + public void CanOnlyAddSpannerParameterOrValidValue() + { + using var command = new SpannerCommand(); + + Assert.DoesNotThrow(() => command.Parameters.Add("hello")); + + Assert.That(() => command.Parameters.Add(new SomeOtherDbParameter()), Throws.Exception.TypeOf()); + Assert.That(() => command.Parameters.Add(null!), Throws.Exception.TypeOf()); + } + + [Test] + public void Clear() + { + var p = new SpannerParameter(); + var c1 = new SpannerCommand(); + var c2 = new SpannerCommand(); + c1.Parameters.Add(p); + Assert.That(c1.Parameters.Count, Is.EqualTo(1)); + Assert.That(c2.Parameters.Count, Is.EqualTo(0)); + c1.Parameters.Clear(); + Assert.That(c1.Parameters.Count, Is.EqualTo(0)); + c2.Parameters.Add(p); + Assert.That(c1.Parameters.Count, Is.EqualTo(0)); + Assert.That(c2.Parameters.Count, Is.EqualTo(1)); + } + + [Test] + public void ParameterRename() + { + using var command = new SpannerCommand(); + for (var i = 0; i < 10; i++) + { + command.AddParameter($"p{i + 1:00}", $"String parameter value {i + 1}"); + } + Assert.That(command.Parameters["p03"].ParameterName, Is.EqualTo("p03")); + + // Rename a parameter. + command.Parameters["p03"].ParameterName = "a_new_name"; + Assert.That(command.Parameters.IndexOf("a_new_name"), Is.GreaterThanOrEqualTo(0)); + } + + [Test] + public void UnnamedParameterRename() + { + using var command = new SpannerCommand(); + + for (var i = 0; i < 3; i++) + { + for (var j = 0; j < 10; j++) + { + // Create and add an unnamed parameter before renaming it + var parameter = command.CreateParameter(); + command.Parameters.Add(parameter); + parameter.ParameterName = $"{j}"; + } + Assert.That(command.Parameters["3"].ParameterName, Is.EqualTo("3")); + command.Parameters.Clear(); + } + } + + [Test] + public void RemoveDuplicateParameter() + { + using var command = new SpannerCommand(); + var count = 10; + for (var i = 0; i < count; i++) + { + command.AddParameter($"p{i + 1:00}", $"String parameter value {i + 1}"); + } + + Assert.That(command.Parameters["p02"].ParameterName, Is.EqualTo("p02")); + // Add uppercased version of the same parameter. + command.AddParameter("P02", "String parameter value 2"); + // Remove the original parameter by its name. + command.Parameters.Remove(command.Parameters["p02"]); + + // Test whether we can still find the last added parameter, and if its index is correctly shifted in the lookup. + Assert.That(command.Parameters.IndexOf("p02"), Is.EqualTo(count - 1)); + Assert.That(command.Parameters.IndexOf("P02"), Is.EqualTo(count - 1)); + // And finally test whether other parameters were also correctly shifted. + Assert.That(command.Parameters.IndexOf("p03"), Is.EqualTo(1)); + Assert.That(command.Parameters.IndexOf("p03") == 1); + } + + [Test] + public void RemoveParameter() + { + using var command = new SpannerCommand(); + var count = 10; + for (var i = 0; i < count; i++) + { + command.AddParameter($"p{i + 1:00}", $"String parameter value {i + 1}"); + } + + // Remove the parameter by its name + command.Parameters.Remove(command.Parameters["p02"]); + + // Make sure we cannot find it, also not case insensitively. + Assert.That(command.Parameters.IndexOf("p02"), Is.EqualTo(-1)); + Assert.That(command.Parameters.IndexOf("P02"), Is.EqualTo(-1)); + } + + [Test] + public void RemoveCaseDifferingParameter() + { + var count = 10; + // Add two parameters that only differ in casing. + using var command = new SpannerCommand(); + command.AddParameter("PP0", 1); + command.AddParameter("Pp0", 1); + for (var i = 0; i < count - 2; i++) + { + command.AddParameter($"pp{i}", i); + } + + // Removing Pp0. + command.Parameters.RemoveAt(1); + + // Matching on parameter name always first prefers case-sensitive matching, so we match entry 1 ('pp0'). + Assert.That(command.Parameters.IndexOf("pp0"), Is.EqualTo(1)); + // Exact match to PP0. + Assert.That(command.Parameters.IndexOf("PP0"), Is.EqualTo(0)); + // Case-insensitive match to PP0. + Assert.That(command.Parameters.IndexOf("Pp0"), Is.EqualTo(0)); + } + + [Test] + public void CorrectIndexReturnedForDuplicateParameterName() + { + const int count = 10; + using var command = new SpannerCommand(); + for (var i = 0; i < count; i++) + { + command.AddParameter($"parameter{i + 1:00}", $"String parameter value {i + 1}"); + } + Assert.That(command.Parameters["parameter02"].ParameterName, Is.EqualTo("parameter02")); + + // Add an upper-case version of one of the parameters. + command.AddParameter("Parameter02", "String parameter value 2"); + + // Insert another case-insensitive before the original. + command.Parameters.Insert(0, new SpannerParameter { ParameterName = "ParameteR02", Value = "String parameter value 2" }); + + // Try to find the exact index. + Assert.That(command.Parameters.IndexOf("parameter02"), Is.EqualTo(2)); + Assert.That(command.Parameters.IndexOf("Parameter02"), Is.EqualTo(command.Parameters.Count - 1)); + Assert.That(command.Parameters.IndexOf("ParameteR02"), Is.EqualTo(0)); + // This name does not exist so we expect the first case-insensitive match to be returned. + Assert.That(command.Parameters.IndexOf("ParaMeteR02"), Is.EqualTo(0)); + + // And finally test whether other parameters were also correctly shifted. + Assert.That(command.Parameters.IndexOf("parameter03"), Is.EqualTo(3)); + } + + [Test] + public void FindsCaseInsensitiveLookups() + { + const int count = 10; + using var command = new SpannerCommand(); + var parameters = command.Parameters; + for (var i = 0; i < count; i++) + { + parameters.Add(new SpannerParameter{ ParameterName = $"p{i}", Value = i }); + } + Assert.That(command.Parameters.IndexOf("P1"), Is.EqualTo(1)); + } + + [Test] + public void FindsCaseSensitiveLookups() + { + const int count = 10; + using var command = new SpannerCommand(); + var parameters = command.Parameters; + for (var i = 0; i < count; i++) + { + parameters.Add(new SpannerParameter{ ParameterName = $"p{i}", Value = i}); + } + Assert.That(command.Parameters.IndexOf("p1"), Is.EqualTo(1)); + } + + [Test] + public void ThrowsOnIndexerMismatch() + { + const int count = 10; + using var command = new SpannerCommand(); + var parameters = command.Parameters; + for (var i = 0; i < count; i++) + { + parameters.Add(new SpannerParameter{ ParameterName = $"p{i}", Value = i}); + } + + Assert.DoesNotThrow(() => + { + command.Parameters["p1"] = new SpannerParameter("p1", 1); + command.Parameters["p1"] = new SpannerParameter("P1", 1); + }); + + Assert.Throws(() => + { + command.Parameters["p1"] = new SpannerParameter("p2", 1); + }); + } + + [Test] + public void PositionalParameterLookupReturnsFirstMatch() + { + const int count = 10; + using var command = new SpannerCommand(); + var parameters = command.Parameters; + for (var i = 0; i < count; i++) + { + parameters.Add(new SpannerParameter("", i)); + } + Assert.That(command.Parameters.IndexOf(""), Is.EqualTo(0)); + } + + [Test] + public void MultiplePositionsSameInstanceIsAllowed() + { + using var cmd = new SpannerCommand(); + cmd.CommandText = "SELECT $1, $2"; + var p = new SpannerParameter("", "Hello world"); + cmd.Parameters.Add(p); + Assert.DoesNotThrow(() => cmd.Parameters.Add(p)); + } + + [Test] + public void IndexOfFallsBackToFirstInsensitiveMatch([Values] bool manyParams) + { + const int count = 10; + using var command = new SpannerCommand(); + var parameters = command.Parameters; + + parameters.Add(new SpannerParameter("foo", 8)); + parameters.Add(new SpannerParameter("bar", 8)); + parameters.Add(new SpannerParameter("BAR", 8)); + + if (manyParams) + { + for (var i = 0; i < count; i++) + { + parameters.Add(new SpannerParameter($"p{i}", i)); + } + } + Assert.That(parameters.IndexOf("Bar"), Is.EqualTo(1)); + } + + [Test] + public void IndexOfPrefersCaseSensitiveMatch([Values] bool manyParams) + { + const int count = 10; + using var command = new SpannerCommand(); + var parameters = command.Parameters; + + parameters.Add(new SpannerParameter("FOO", 8)); + parameters.Add(new SpannerParameter("foo", 8)); + + if (manyParams) + { + for (var i = 0; i < count; i++) + { + parameters.Add(new SpannerParameter($"p{i}", i)); + } + } + Assert.That(parameters.IndexOf("foo"), Is.EqualTo(1)); + } + + [Test] + public void CloningSucceeds() + { + const int count = 10; + var command = new SpannerCommand(); + for (var i = 0; i < count; i++) + { + command.Parameters.Add(new SpannerParameter()); + } + Assert.DoesNotThrow(() => command.Clone()); + } + + [Test] + public void CleanName() + { + var param = new SpannerParameter(); + var command = new SpannerCommand(); + command.Parameters.Add(param); + + param.ParameterName = null; + + // These should not throw exceptions + Assert.That(command.Parameters.IndexOf(param.ParameterName), Is.EqualTo(0)); + Assert.That(param.ParameterName, Is.EqualTo("")); + } + + class SomeOtherDbParameter : DbParameter + { + public override void ResetDbType() {} + + public override DbType DbType { get; set; } + public override ParameterDirection Direction { get; set; } + public override bool IsNullable { get; set; } + [AllowNull] public override string ParameterName { get; set; } = ""; + [AllowNull] public override string SourceColumn { get; set; } = ""; + public override object? Value { get; set; } + public override bool SourceColumnNullMapping { get; set; } + public override int Size { get; set; } + } + +} \ No newline at end of file diff --git a/drivers/spanner-ado-net/spanner-ado-net-tests/SpannerParameterTests.cs b/drivers/spanner-ado-net/spanner-ado-net-tests/SpannerParameterTests.cs new file mode 100644 index 00000000..cf702c73 --- /dev/null +++ b/drivers/spanner-ado-net/spanner-ado-net-tests/SpannerParameterTests.cs @@ -0,0 +1,214 @@ +using System.Data; +using System.Data.Common; +using Google.Cloud.Spanner.V1; +using Google.Cloud.SpannerLib.MockServer; +using TypeCode = Google.Cloud.Spanner.V1.TypeCode; + +namespace Google.Cloud.Spanner.DataProvider.Tests; + +public class SpannerParameterTests : AbstractMockServerTests +{ + [Test] + public void SettingValueDoesNotChangeDbType() + { + // ReSharper disable once UseObjectOrCollectionInitializer + var p = new SpannerParameter { DbType = DbType.String }; + p.Value = 8; + Assert.That(p.DbType, Is.EqualTo(DbType.String)); + } + + [Test] + public void DefaultConstructor() + { + var p = new SpannerParameter(); + Assert.That(p.DbType, Is.EqualTo(DbType.String), "DbType"); + Assert.That(p.Direction, Is.EqualTo(ParameterDirection.Input), "Direction"); + Assert.That(p.IsNullable, Is.False, "IsNullable"); + Assert.That(p.ParameterName, Is.Empty, "ParameterName"); + Assert.That(p.Precision, Is.EqualTo(0), "Precision"); + Assert.That(p.Scale, Is.EqualTo(0), "Scale"); + Assert.That(p.Size, Is.EqualTo(0), "Size"); + Assert.That(p.SourceColumn, Is.Empty, "SourceColumn"); + Assert.That(p.SourceVersion, Is.EqualTo(DataRowVersion.Current), "SourceVersion"); + Assert.That(p.Value, Is.Null, "Value"); + } + + [Test] + public void ConstructorValueDateTime() + { + var value = new DateTime(2004, 8, 24); + + var p = new SpannerParameter("address", value); + // Setting a parameter value does not change the type. + Assert.That(p.DbType, Is.EqualTo(DbType.String), "B:DbType"); + Assert.That(p.Direction, Is.EqualTo(ParameterDirection.Input), "B:Direction"); + Assert.That(p.IsNullable, Is.False, "B:IsNullable"); + Assert.That(p.ParameterName, Is.EqualTo("address"), "B:ParameterName"); + Assert.That(p.Precision, Is.EqualTo(0), "B:Precision"); + Assert.That(p.Scale, Is.EqualTo(0), "B:Scale"); + Assert.That(p.Size, Is.EqualTo(0), "B:Size"); + Assert.That(p.SourceColumn, Is.Empty, "B:SourceColumn"); + Assert.That(p.SourceVersion, Is.EqualTo(DataRowVersion.Current), "B:SourceVersion"); + Assert.That(p.Value, Is.EqualTo(value), "B:Value"); + } + + [Test] + public void ConstructorValueDbNull() + { + var p = new SpannerParameter("address", DBNull.Value); + Assert.That(p.DbType, Is.EqualTo(DbType.String), "B:DbType"); + Assert.That(p.Direction, Is.EqualTo(ParameterDirection.Input), "B:Direction"); + Assert.That(p.IsNullable, Is.False, "B:IsNullable"); + Assert.That(p.ParameterName, Is.EqualTo("address"), "B:ParameterName"); + Assert.That(p.Precision, Is.EqualTo(0), "B:Precision"); + Assert.That(p.Scale, Is.EqualTo(0), "B:Scale"); + Assert.That(p.Size, Is.EqualTo(0), "B:Size"); + Assert.That(p.SourceColumn, Is.Empty, "B:SourceColumn"); + Assert.That(p.SourceVersion, Is.EqualTo(DataRowVersion.Current), "B:SourceVersion"); + Assert.That(p.Value, Is.EqualTo(DBNull.Value), "B:Value"); + } + + [Test] + public void ConstructorValueNull() + { + var p = new SpannerParameter("address", null); + Assert.That(p.DbType, Is.EqualTo(DbType.String), "A:DbType"); + Assert.That(p.Direction, Is.EqualTo(ParameterDirection.Input), "A:Direction"); + Assert.That(p.IsNullable, Is.False, "A:IsNullable"); + Assert.That(p.ParameterName, Is.EqualTo("address"), "A:ParameterName"); + Assert.That(p.Precision, Is.EqualTo(0), "A:Precision"); + Assert.That(p.Scale, Is.EqualTo(0), "A:Scale"); + Assert.That(p.Size, Is.EqualTo(0), "A:Size"); + Assert.That(p.SourceColumn, Is.Empty, "A:SourceColumn"); + Assert.That(p.SourceVersion, Is.EqualTo(DataRowVersion.Current), "A:SourceVersion"); + Assert.That(p.Value, Is.Null, "A:Value"); + } + + [Test] + public void Clone() + { + var expected = new SpannerParameter + { + Value = 42, + ParameterName = "TheAnswer", + + DbType = DbType.Int32, + + Direction = ParameterDirection.InputOutput, + IsNullable = true, + Precision = 1, + Scale = 2, + Size = 4, + + SourceVersion = DataRowVersion.Proposed, + SourceColumn = "source", + SourceColumnNullMapping = true, + }; + var actual = expected.Clone(); + + Assert.That(actual.Value, Is.EqualTo(expected.Value)); + Assert.That(actual.ParameterName, Is.EqualTo(expected.ParameterName)); + + Assert.That(actual.DbType, Is.EqualTo(expected.DbType)); + + Assert.That(actual.Direction, Is.EqualTo(expected.Direction)); + Assert.That(actual.IsNullable, Is.EqualTo(expected.IsNullable)); + Assert.That(actual.Precision, Is.EqualTo(expected.Precision)); + Assert.That(actual.Scale, Is.EqualTo(expected.Scale)); + Assert.That(actual.Size, Is.EqualTo(expected.Size)); + + Assert.That(actual.SourceVersion, Is.EqualTo(expected.SourceVersion)); + Assert.That(actual.SourceColumn, Is.EqualTo(expected.SourceColumn)); + Assert.That(actual.SourceColumnNullMapping, Is.EqualTo(expected.SourceColumnNullMapping)); + } + + [Test] + public void ParameterNull() + { + var param = new SpannerParameter{ParameterName = "param", DbType = DbType.Decimal}; + Assert.That(param.Scale, Is.EqualTo(0), "#A1"); + param.Value = DBNull.Value; + Assert.That(param.Scale, Is.EqualTo(0), "#A2"); + + param = new SpannerParameter{ParameterName = "param", DbType = DbType.Int32}; + Assert.That(param.Scale, Is.EqualTo(0), "#B1"); + param.Value = DBNull.Value; + Assert.That(param.Scale, Is.EqualTo(0), "#B2"); + } + + [Test] + public async Task MatchParamIndexCaseInsensitively() + { + const string sql = "SELECT @p,@P"; + Fixture.SpannerMock.AddOrUpdateStatementResult(sql, StatementResult.CreateResultSet( + new List>([Tuple.Create(TypeCode.String, "p"), Tuple.Create(TypeCode.String, "p")]), + new List([["Hello World", "Hello World"]]))); + + await using var conn = await OpenConnectionAsync(); + await using var cmd = new SpannerCommand(sql, conn); + cmd.AddParameter("p", "Hello World"); + await cmd.ExecuteNonQueryAsync(); + + var request = Fixture.SpannerMock.Requests.OfType().Single(r => r.Sql == sql); + Assert.That(request, Is.Not.Null); + // TODO: Revisit once https://github.com/googleapis/go-sql-spanner/issues/594 has been decided. + Assert.That(request.Params.Fields.Count, Is.EqualTo(2)); + Assert.That(request.Params.Fields["p"].StringValue, Is.EqualTo("Hello World")); + Assert.That(request.Params.Fields["P"].HasNullValue); + } + + [Test] + public void PrecisionViaInterface() + { + var parameter = new SpannerParameter(); + var paramIface = (IDbDataParameter)parameter; + + paramIface.Precision = 42; + + Assert.That(paramIface.Precision, Is.EqualTo((byte)42)); + } + + [Test] + public void PrecisionViaBaseClass() + { + var parameter = new SpannerParameter(); + var paramBase = (DbParameter)parameter; + + paramBase.Precision = 42; + + Assert.That(paramBase.Precision, Is.EqualTo((byte)42)); + } + + [Test] + public void ScaleViaInterface() + { + var parameter = new SpannerParameter(); + var paramIface = (IDbDataParameter)parameter; + + paramIface.Scale = 42; + + Assert.That(paramIface.Scale, Is.EqualTo((byte)42)); + } + + [Test] + public void ScaleViaBaseClass() + { + var parameter = new SpannerParameter(); + var paramBase = (DbParameter)parameter; + + paramBase.Scale = 42; + + Assert.That(paramBase.Scale, Is.EqualTo((byte)42)); + } + + [Test] + public void NullValueThrows() + { + using var connection = OpenConnection(); + using var command = new SpannerCommand("SELECT @p", connection); + command.Parameters.Add(new SpannerParameter("p", null)); + + Assert.That(() => command.ExecuteReader(), Throws.InvalidOperationException); + } + +} \ No newline at end of file diff --git a/drivers/spanner-ado-net/spanner-ado-net-tests/TestUtils.cs b/drivers/spanner-ado-net/spanner-ado-net-tests/TestUtils.cs new file mode 100644 index 00000000..24f9b859 --- /dev/null +++ b/drivers/spanner-ado-net/spanner-ado-net-tests/TestUtils.cs @@ -0,0 +1,19 @@ +namespace Google.Cloud.Spanner.DataProvider.Tests; + +public static class TestUtils +{ + private const string Chars = "ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"; + + public static string GenerateRandomString(int length) + { + return new string(Enumerable.Repeat(Chars, length) + .Select(s => s[Random.Shared.Next(s.Length)]).ToArray()); + } + +} + +public enum PrepareOrNot +{ + Prepared, + NotPrepared +} diff --git a/drivers/spanner-ado-net/spanner-ado-net-tests/TransactionTests.cs b/drivers/spanner-ado-net/spanner-ado-net-tests/TransactionTests.cs new file mode 100644 index 00000000..3a0353d1 --- /dev/null +++ b/drivers/spanner-ado-net/spanner-ado-net-tests/TransactionTests.cs @@ -0,0 +1,182 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +using Google.Cloud.Spanner.V1; +using Google.Cloud.SpannerLib.MockServer; + +namespace Google.Cloud.Spanner.DataProvider.Tests; + +public class TransactionTests : AbstractMockServerTests +{ + [Test] + public async Task TestReadWriteTransaction() + { + const string sql = "update my_table set my_column=@value where id=@id"; + Fixture.SpannerMock.AddOrUpdateStatementResult(sql, StatementResult.CreateUpdateCount(1L)); + + await using var connection = new SpannerConnection(); + connection.ConnectionString = ConnectionString; + await connection.OpenAsync(); + await using var transaction = await connection.BeginTransactionAsync(); + await using var command = connection.CreateCommand(); + command.CommandText = sql; + var paramId = command.CreateParameter(); + paramId.ParameterName = "id"; + paramId.Value = 1; + command.Parameters.Add(paramId); + var paramValue = command.CreateParameter(); + paramValue.ParameterName = "value"; + paramValue.Value = "One"; + command.Parameters.Add(paramValue); + var updateCount = await command.ExecuteNonQueryAsync(); + await transaction.CommitAsync(); + + Assert.That(updateCount, Is.EqualTo(1)); + var requests = Fixture.SpannerMock.Requests.ToList(); + // The transaction should use inline-begin. + Assert.That(requests.OfType().Count(), Is.EqualTo(0)); + Assert.That(requests.OfType().Count(), Is.EqualTo(1)); + Assert.That(requests.OfType().Count(), Is.EqualTo(1)); + var executeRequest = requests.OfType().First(); + Assert.That(executeRequest.Transaction, Is.EqualTo(new TransactionSelector + { + Begin = new TransactionOptions + { + ReadWrite = new TransactionOptions.Types.ReadWrite(), + } + })); + } + + [Test] + public async Task TestReadOnlyTransaction() + { + const string sql = "select value from my_table where id=@id"; + Fixture.SpannerMock.AddOrUpdateStatementResult(sql, StatementResult.CreateSingleColumnResultSet(new V1.Type{Code = V1.TypeCode.String}, "value", "One")); + + await using var connection = new SpannerConnection(); + connection.ConnectionString = ConnectionString; + await connection.OpenAsync(); + await using var transaction = connection.BeginReadOnlyTransaction(); + await using var command = connection.CreateCommand(); + command.CommandText = sql; + var paramId = command.CreateParameter(); + paramId.ParameterName = "id"; + paramId.Value = 1; + command.Parameters.Add(paramId); + await using var reader = await command.ExecuteReaderAsync(); + Assert.That(await reader.ReadAsync()); + Assert.That(reader.FieldCount, Is.EqualTo(1)); + Assert.That(reader.GetValue(0), Is.EqualTo("One")); + Assert.That(await reader.ReadAsync(), Is.False); + + // We must commit the transaction in order to end it. + await transaction.CommitAsync(); + + var requests = Fixture.SpannerMock.Requests.ToList(); + // The transaction should use inline-begin. + Assert.That(requests.OfType().Count(), Is.EqualTo(0)); + Assert.That(requests.OfType().Count(), Is.EqualTo(1)); + // Committing a read-only transaction is a no-op on Spanner. + Assert.That(requests.OfType().Count(), Is.EqualTo(0)); + var executeRequest = requests.OfType().First(); + Assert.That(executeRequest.Transaction, Is.EqualTo(new TransactionSelector + { + Begin = new TransactionOptions + { + ReadOnly = new TransactionOptions.Types.ReadOnly + { + Strong = true, + ReturnReadTimestamp = true, + }, + } + })); + } + + [Ignore("Needs a fix in SpannerLib")] + [Test] + public async Task TestTransactionTag() + { + const string select = "select value from my_table where id=@id"; + Fixture.SpannerMock.AddOrUpdateStatementResult(select, StatementResult.CreateSingleColumnResultSet(new V1.Type{Code = V1.TypeCode.String}, "value", "one")); + const string update = "update my_table set my_column=@value where id=@id"; + Fixture.SpannerMock.AddOrUpdateStatementResult(update, StatementResult.CreateUpdateCount(1L)); + + await using var connection = new SpannerConnection(); + connection.ConnectionString = ConnectionString; + await connection.OpenAsync(); + await using var setTagCommand = connection.CreateCommand(); + setTagCommand.CommandText = "set transaction_tag='test_tag'"; + await setTagCommand.ExecuteNonQueryAsync(); + await using var transaction = await connection.BeginTransactionAsync(); + + await using var command = connection.CreateCommand(); + command.CommandText = select; + var selectParamId = command.CreateParameter(); + selectParamId.ParameterName = "id"; + selectParamId.Value = 1; + command.Parameters.Add(selectParamId); + await using var reader = await command.ExecuteReaderAsync(); + Assert.That(await reader.ReadAsync()); + Assert.That(reader.FieldCount, Is.EqualTo(1)); + Assert.That(reader.GetValue(0), Is.EqualTo("one")); + Assert.That(await reader.ReadAsync(), Is.False); + + await using var updateCommand = connection.CreateCommand(); + updateCommand.CommandText = update; + var paramId = updateCommand.CreateParameter(); + paramId.ParameterName = "id"; + paramId.Value = 1; + updateCommand.Parameters.Add(paramId); + var paramValue = updateCommand.CreateParameter(); + paramValue.ParameterName = "value"; + paramValue.Value = "One"; + updateCommand.Parameters.Add(paramValue); + var updateCount = await updateCommand.ExecuteNonQueryAsync(); + await transaction.CommitAsync(); + + Assert.That(updateCount, Is.EqualTo(1)); + var requests = Fixture.SpannerMock.Requests.ToList(); + // The transaction should use inline-begin. + Assert.That(requests.OfType().Count(), Is.EqualTo(0)); + Assert.That(requests.OfType().Count(), Is.EqualTo(2)); + Assert.That(requests.OfType().Count(), Is.EqualTo(1)); + var selectRequest = requests.OfType().First(); + Assert.That(selectRequest.Transaction, Is.EqualTo(new TransactionSelector + { + Begin = new TransactionOptions + { + ReadWrite = new TransactionOptions.Types.ReadWrite(), + } + })); + Assert.That(selectRequest.RequestOptions.TransactionTag, Is.EqualTo("test_tag")); + var updateRequest = requests.OfType().Single(request => request.Sql == update); + Assert.That(updateRequest.RequestOptions.TransactionTag, Is.EqualTo("test_tag")); + var commitRequest = requests.OfType().Single(); + Assert.That(commitRequest.RequestOptions.TransactionTag, Is.EqualTo("test_tag")); + + // The next transaction should not use the tag. + await using var tx2 = await connection.BeginTransactionAsync(); + await using var command2 = connection.CreateCommand(); + command2.CommandText = update; + command2.Parameters.Add(paramId); + command2.Parameters.Add(paramValue); + await command2.ExecuteNonQueryAsync(); + await tx2.CommitAsync(); + + var lastRequest = requests.OfType().Last(request => request.Sql == update); + Assert.That(lastRequest.RequestOptions.TransactionTag, Is.Null); + var lastCommitRequest = requests.OfType().Last(); + Assert.That(lastCommitRequest.RequestOptions.TransactionTag, Is.Null); + } +} \ No newline at end of file diff --git a/drivers/spanner-ado-net/spanner-ado-net-tests/appsettings.json b/drivers/spanner-ado-net/spanner-ado-net-tests/appsettings.json new file mode 100644 index 00000000..2a8537d3 --- /dev/null +++ b/drivers/spanner-ado-net/spanner-ado-net-tests/appsettings.json @@ -0,0 +1,8 @@ +{ + "Logging": { + "IncludeScopes": false, + "LogLevel": { + "Microsoft": "Warning" + } + } +} diff --git a/drivers/spanner-ado-net/spanner-ado-net-tests/spanner-ado-net-tests.csproj b/drivers/spanner-ado-net/spanner-ado-net-tests/spanner-ado-net-tests.csproj new file mode 100644 index 00000000..31870103 --- /dev/null +++ b/drivers/spanner-ado-net/spanner-ado-net-tests/spanner-ado-net-tests.csproj @@ -0,0 +1,35 @@ + + + + net8.0 + Google.Cloud.Spanner.DataProvider.Tests + enable + enable + + false + true + Google.Cloud.Spanner.DataProvider.Tests + default + + + + + all + runtime; build; native; contentfiles; analyzers; buildtransitive + + + + + + + + + + + + + + + + + diff --git a/drivers/spanner-ado-net/spanner-ado-net.sln b/drivers/spanner-ado-net/spanner-ado-net.sln new file mode 100644 index 00000000..7a24c1a4 --- /dev/null +++ b/drivers/spanner-ado-net/spanner-ado-net.sln @@ -0,0 +1,70 @@ + +Microsoft Visual Studio Solution File, Format Version 12.00 +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "spanner-ado-net", "spanner-ado-net\spanner-ado-net.csproj", "{C01E227F-E396-45E7-A82F-478EFA9AC0A6}" +EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "spanner-ado-net-tests", "spanner-ado-net-tests\spanner-ado-net-tests.csproj", "{56052199-927F-46F5-8D0F-4826360E70B8}" +EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "spanner-ado-net-specification-tests", "spanner-ado-net-specification-tests\spanner-ado-net-specification-tests.csproj", "{97D93DB7-CEB6-4C21-B4C2-A5A98D3FD59C}" +EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "spanner-ado-net-samples", "spanner-ado-net-samples\spanner-ado-net-samples.csproj", "{537A257C-0228-418F-9DD5-A46324E591AE}" +EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "spanner-ado-net-benchmarks", "spanner-ado-net-benchmarks\spanner-ado-net-benchmarks.csproj", "{2C70D969-A8AA-440B-81D8-532C327F237E}" +EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "spannerlib-dotnet-mockserver", "..\..\spannerlib\wrappers\spannerlib-dotnet\spannerlib-dotnet-mockserver\spannerlib-dotnet-mockserver.csproj", "{E690FD52-65CD-4F11-A56E-A7D3B8D7A190}" +EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "spannerlib-dotnet", "..\..\spannerlib\wrappers\spannerlib-dotnet\spannerlib-dotnet\spannerlib-dotnet.csproj", "{90663BC7-07FD-4089-9594-94D61D9F63A2}" +EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "spannerlib-dotnet-grpc-impl", "..\..\spannerlib\wrappers\spannerlib-dotnet\spannerlib-dotnet-grpc-impl\spannerlib-dotnet-grpc-impl.csproj", "{8759AB44-DEC6-4E78-B64D-2EE4A403DFE1}" +EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "spannerlib-dotnet-native-impl", "..\..\spannerlib\wrappers\spannerlib-dotnet\spannerlib-dotnet-native-impl\spannerlib-dotnet-native-impl.csproj", "{85711FA3-547A-4B8E-AA23-95A6108F0DF8}" +EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "spannerlib-dotnet-grpc-v1", "..\..\spannerlib\wrappers\spannerlib-dotnet\spannerlib-dotnet-grpc-v1\spannerlib-dotnet-grpc-v1.csproj", "{DF3C6D80-EB58-4189-A15A-9D3FEA233AF0}" +EndProject +Global + GlobalSection(SolutionConfigurationPlatforms) = preSolution + Debug|Any CPU = Debug|Any CPU + Release|Any CPU = Release|Any CPU + EndGlobalSection + GlobalSection(ProjectConfigurationPlatforms) = postSolution + {C01E227F-E396-45E7-A82F-478EFA9AC0A6}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {C01E227F-E396-45E7-A82F-478EFA9AC0A6}.Debug|Any CPU.Build.0 = Debug|Any CPU + {C01E227F-E396-45E7-A82F-478EFA9AC0A6}.Release|Any CPU.ActiveCfg = Release|Any CPU + {C01E227F-E396-45E7-A82F-478EFA9AC0A6}.Release|Any CPU.Build.0 = Release|Any CPU + {56052199-927F-46F5-8D0F-4826360E70B8}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {56052199-927F-46F5-8D0F-4826360E70B8}.Debug|Any CPU.Build.0 = Debug|Any CPU + {56052199-927F-46F5-8D0F-4826360E70B8}.Release|Any CPU.ActiveCfg = Release|Any CPU + {56052199-927F-46F5-8D0F-4826360E70B8}.Release|Any CPU.Build.0 = Release|Any CPU + {97D93DB7-CEB6-4C21-B4C2-A5A98D3FD59C}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {97D93DB7-CEB6-4C21-B4C2-A5A98D3FD59C}.Debug|Any CPU.Build.0 = Debug|Any CPU + {97D93DB7-CEB6-4C21-B4C2-A5A98D3FD59C}.Release|Any CPU.ActiveCfg = Release|Any CPU + {97D93DB7-CEB6-4C21-B4C2-A5A98D3FD59C}.Release|Any CPU.Build.0 = Release|Any CPU + {537A257C-0228-418F-9DD5-A46324E591AE}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {537A257C-0228-418F-9DD5-A46324E591AE}.Debug|Any CPU.Build.0 = Debug|Any CPU + {537A257C-0228-418F-9DD5-A46324E591AE}.Release|Any CPU.ActiveCfg = Release|Any CPU + {537A257C-0228-418F-9DD5-A46324E591AE}.Release|Any CPU.Build.0 = Release|Any CPU + {2C70D969-A8AA-440B-81D8-532C327F237E}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {2C70D969-A8AA-440B-81D8-532C327F237E}.Debug|Any CPU.Build.0 = Debug|Any CPU + {2C70D969-A8AA-440B-81D8-532C327F237E}.Release|Any CPU.ActiveCfg = Release|Any CPU + {2C70D969-A8AA-440B-81D8-532C327F237E}.Release|Any CPU.Build.0 = Release|Any CPU + {E690FD52-65CD-4F11-A56E-A7D3B8D7A190}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {E690FD52-65CD-4F11-A56E-A7D3B8D7A190}.Debug|Any CPU.Build.0 = Debug|Any CPU + {E690FD52-65CD-4F11-A56E-A7D3B8D7A190}.Release|Any CPU.ActiveCfg = Release|Any CPU + {E690FD52-65CD-4F11-A56E-A7D3B8D7A190}.Release|Any CPU.Build.0 = Release|Any CPU + {90663BC7-07FD-4089-9594-94D61D9F63A2}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {90663BC7-07FD-4089-9594-94D61D9F63A2}.Debug|Any CPU.Build.0 = Debug|Any CPU + {90663BC7-07FD-4089-9594-94D61D9F63A2}.Release|Any CPU.ActiveCfg = Release|Any CPU + {90663BC7-07FD-4089-9594-94D61D9F63A2}.Release|Any CPU.Build.0 = Release|Any CPU + {8759AB44-DEC6-4E78-B64D-2EE4A403DFE1}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {8759AB44-DEC6-4E78-B64D-2EE4A403DFE1}.Debug|Any CPU.Build.0 = Debug|Any CPU + {8759AB44-DEC6-4E78-B64D-2EE4A403DFE1}.Release|Any CPU.ActiveCfg = Release|Any CPU + {8759AB44-DEC6-4E78-B64D-2EE4A403DFE1}.Release|Any CPU.Build.0 = Release|Any CPU + {85711FA3-547A-4B8E-AA23-95A6108F0DF8}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {85711FA3-547A-4B8E-AA23-95A6108F0DF8}.Debug|Any CPU.Build.0 = Debug|Any CPU + {85711FA3-547A-4B8E-AA23-95A6108F0DF8}.Release|Any CPU.ActiveCfg = Release|Any CPU + {85711FA3-547A-4B8E-AA23-95A6108F0DF8}.Release|Any CPU.Build.0 = Release|Any CPU + {DF3C6D80-EB58-4189-A15A-9D3FEA233AF0}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {DF3C6D80-EB58-4189-A15A-9D3FEA233AF0}.Debug|Any CPU.Build.0 = Debug|Any CPU + {DF3C6D80-EB58-4189-A15A-9D3FEA233AF0}.Release|Any CPU.ActiveCfg = Release|Any CPU + {DF3C6D80-EB58-4189-A15A-9D3FEA233AF0}.Release|Any CPU.Build.0 = Release|Any CPU + EndGlobalSection +EndGlobal diff --git a/drivers/spanner-ado-net/spanner-ado-net/AssemblyInfo.cs b/drivers/spanner-ado-net/spanner-ado-net/AssemblyInfo.cs new file mode 100644 index 00000000..9c773122 --- /dev/null +++ b/drivers/spanner-ado-net/spanner-ado-net/AssemblyInfo.cs @@ -0,0 +1,3 @@ +using System.Runtime.CompilerServices; +[assembly:InternalsVisibleTo("Google.Cloud.Spanner.DataProvider.Tests")] +[assembly:InternalsVisibleTo("Google.Cloud.Spanner.DataProvider.SpecificationTests")] diff --git a/drivers/spanner-ado-net/spanner-ado-net/Preconditions.cs b/drivers/spanner-ado-net/spanner-ado-net/Preconditions.cs new file mode 100644 index 00000000..c3b4867c --- /dev/null +++ b/drivers/spanner-ado-net/spanner-ado-net/Preconditions.cs @@ -0,0 +1,25 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +using System; + +namespace Google.Cloud.Spanner.DataProvider; + +internal static class Preconditions +{ + internal static int CheckIndexRange(int argument, string paramName, int minInclusive, int maxInclusive) => + argument < minInclusive || argument > maxInclusive ? + throw new IndexOutOfRangeException($"Value {argument} should be in range [{minInclusive}, {maxInclusive}]") : argument; + +} \ No newline at end of file diff --git a/drivers/spanner-ado-net/spanner-ado-net/README.md b/drivers/spanner-ado-net/spanner-ado-net/README.md new file mode 100644 index 00000000..fbcfda13 --- /dev/null +++ b/drivers/spanner-ado-net/spanner-ado-net/README.md @@ -0,0 +1,5 @@ +# Spanner ADO.NET Data Provider + +ADO.NET Data Provider for Spanner. + +__ALPHA: Not for production use__ diff --git a/drivers/spanner-ado-net/spanner-ado-net/SpannerBatch.cs b/drivers/spanner-ado-net/spanner-ado-net/SpannerBatch.cs new file mode 100644 index 00000000..30ca2cab --- /dev/null +++ b/drivers/spanner-ado-net/spanner-ado-net/SpannerBatch.cs @@ -0,0 +1,129 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +using System.Collections.Generic; +using System.Data; +using System.Data.Common; +using System.Linq; +using System.Threading; +using System.Threading.Tasks; +using Google.Api.Gax; +using Google.Cloud.Spanner.V1; + +namespace Google.Cloud.Spanner.DataProvider; + +/// +/// SpannerBatch is the Spanner-specific implementation of DbBatch. SpannerBatch supports batches of DML or DDL +/// statements. Note that all statements in a batch must be of the same type. Batches of queries or DML statements with +/// a THEN RETURN / RETURNING clause are not supported. +/// +public class SpannerBatch : DbBatch +{ + private SpannerConnection SpannerConnection => (SpannerConnection)Connection!; + protected override SpannerBatchCommandCollection DbBatchCommands { get; } = new(); + public override int Timeout { get; set; } + protected override DbConnection? DbConnection { get; set; } + protected override DbTransaction? DbTransaction { get; set; } + + public SpannerBatch() + {} + + internal SpannerBatch(SpannerConnection connection) + { + Connection = GaxPreconditions.CheckNotNull(connection, nameof(connection)); + } + + protected override DbDataReader ExecuteDbDataReader(CommandBehavior behavior) + { + throw new System.NotImplementedException(); + } + + protected override Task ExecuteDbDataReaderAsync(CommandBehavior behavior, CancellationToken cancellationToken) + { + throw new System.NotImplementedException(); + } + + private List CreateStatements() + { + var statements = new List(DbBatchCommands.Count); + foreach (var command in DbBatchCommands) + { + var spannerParams = ((SpannerParameterCollection)command.Parameters).CreateSpannerParams(); + var queryParams = spannerParams.Item1; + var paramTypes = spannerParams.Item2; + var batchStatement = new ExecuteBatchDmlRequest.Types.Statement + { + Sql = command.CommandText, + Params = queryParams, + }; + batchStatement.ParamTypes.Add(paramTypes); + statements.Add(batchStatement); + } + return statements; + } + + public override int ExecuteNonQuery() + { + if (DbBatchCommands.Count == 0) + { + return 0; + } + var statements = CreateStatements(); + var results = SpannerConnection.LibConnection!.ExecuteBatch(statements); + DbBatchCommands.SetAffected(results); + return (int) results.Sum(); + } + + public override async Task ExecuteNonQueryAsync(CancellationToken cancellationToken = default) + { + if (DbBatchCommands.Count == 0) + { + return 0; + } + var statements = CreateStatements(); + var results = await SpannerConnection.LibConnection!.ExecuteBatchAsync(statements); + DbBatchCommands.SetAffected(results); + return (int) results.Sum(); + } + + public override object? ExecuteScalar() + { + throw new System.NotImplementedException(); + } + + public override Task ExecuteScalarAsync(CancellationToken cancellationToken = default) + { + throw new System.NotImplementedException(); + } + + public override void Prepare() + { + throw new System.NotImplementedException(); + } + + public override Task PrepareAsync(CancellationToken cancellationToken = default) + { + throw new System.NotImplementedException(); + } + + public override void Cancel() + { + throw new System.NotImplementedException(); + } + + protected override DbBatchCommand CreateDbBatchCommand() + { + return new SpannerBatchCommand(); + } +} \ No newline at end of file diff --git a/drivers/spanner-ado-net/spanner-ado-net/SpannerBatchCommand.cs b/drivers/spanner-ado-net/spanner-ado-net/SpannerBatchCommand.cs new file mode 100644 index 00000000..4e93a7cd --- /dev/null +++ b/drivers/spanner-ado-net/spanner-ado-net/SpannerBatchCommand.cs @@ -0,0 +1,34 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +using System.Data; +using System.Data.Common; + +namespace Google.Cloud.Spanner.DataProvider; + +public class SpannerBatchCommand : DbBatchCommand +{ + public override string CommandText { get; set; } = ""; + public override CommandType CommandType { get; set; } + + internal int InternalRecordsAffected; + public override int RecordsAffected => InternalRecordsAffected; + protected override DbParameterCollection DbParameterCollection { get; } = new SpannerParameterCollection(); + public override bool CanCreateParameter => true; + + public override DbParameter CreateParameter() + { + return new SpannerParameter(); + } +} \ No newline at end of file diff --git a/drivers/spanner-ado-net/spanner-ado-net/SpannerBatchCommandCollection.cs b/drivers/spanner-ado-net/spanner-ado-net/SpannerBatchCommandCollection.cs new file mode 100644 index 00000000..84f91f9e --- /dev/null +++ b/drivers/spanner-ado-net/spanner-ado-net/SpannerBatchCommandCollection.cs @@ -0,0 +1,95 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +using System.Collections.Generic; +using System.Data.Common; +using Google.Api.Gax; + +namespace Google.Cloud.Spanner.DataProvider; + +public class SpannerBatchCommandCollection : DbBatchCommandCollection +{ + private readonly List _commands = new (); + public override int Count => _commands.Count; + public override bool IsReadOnly => false; + + internal void SetAffected(long[] affected) + { + for (var i = 0; i < _commands.Count; i++) + { + _commands[i].InternalRecordsAffected = (int) affected[i]; + } + } + + public override IEnumerator GetEnumerator() + { + return _commands.GetEnumerator(); + } + + public override void Add(DbBatchCommand item) + { + GaxPreconditions.CheckNotNull(item, nameof(item)); + GaxPreconditions.CheckArgument(item is SpannerBatchCommand, nameof(item), "Item must be a SpannerBatchCommand"); + _commands.Add((SpannerBatchCommand)item); + } + + public override void Clear() + { + _commands.Clear(); + } + + public override bool Contains(DbBatchCommand item) + { + GaxPreconditions.CheckArgument(item is SpannerBatchCommand, nameof(item), "Item must be a SpannerBatchCommand"); + return _commands.Contains((SpannerBatchCommand)item); + } + + public override void CopyTo(DbBatchCommand[] array, int arrayIndex) + { + throw new System.NotImplementedException(); + } + + public override bool Remove(DbBatchCommand item) + { + GaxPreconditions.CheckArgument(item is SpannerBatchCommand, nameof(item), "Item must be a SpannerBatchCommand"); + return _commands.Remove((SpannerBatchCommand)item); + } + + public override int IndexOf(DbBatchCommand item) + { + GaxPreconditions.CheckArgument(item is SpannerBatchCommand, nameof(item), "Item must be a SpannerBatchCommand"); + return _commands.IndexOf((SpannerBatchCommand)item); + } + + public override void Insert(int index, DbBatchCommand item) + { + GaxPreconditions.CheckArgument(item is SpannerBatchCommand, nameof(item), "Item must be a SpannerBatchCommand"); + _commands.Insert(index, (SpannerBatchCommand)item); + } + + public override void RemoveAt(int index) + { + _commands.RemoveAt(index); + } + + protected override SpannerBatchCommand GetBatchCommand(int index) + { + return _commands[index]; + } + + protected override void SetBatchCommand(int index, DbBatchCommand batchCommand) + { + _commands[index] = (SpannerBatchCommand)batchCommand; + } +} \ No newline at end of file diff --git a/drivers/spanner-ado-net/spanner-ado-net/SpannerCommand.cs b/drivers/spanner-ado-net/spanner-ado-net/SpannerCommand.cs new file mode 100644 index 00000000..bc2f0661 --- /dev/null +++ b/drivers/spanner-ado-net/spanner-ado-net/SpannerCommand.cs @@ -0,0 +1,388 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +using System; +using System.Data; +using System.Data.Common; +using System.Diagnostics.CodeAnalysis; +using System.Threading; +using System.Threading.Tasks; +using Google.Api.Gax; +using Google.Cloud.Spanner.V1; +using Google.Cloud.SpannerLib; +using Google.Protobuf.WellKnownTypes; +using static Google.Cloud.Spanner.DataProvider.SpannerDbException; + +namespace Google.Cloud.Spanner.DataProvider; + +public class SpannerCommand : DbCommand, ICloneable +{ + private SpannerConnection SpannerConnection => (SpannerConnection)Connection!; + + private string _commandText = ""; + [AllowNull] public override string CommandText { get => _commandText; set => _commandText = value ?? ""; } + + private int? _timeout; + + public override int CommandTimeout + { + get => _timeout ?? (int?) SpannerConnection?.DefaultCommandTimeout ?? 0; + set => _timeout = value; + } + + public override CommandType CommandType { get; set; } = CommandType.Text; + + public override UpdateRowSource UpdatedRowSource { get; set; } = UpdateRowSource.Both; + protected override DbConnection? DbConnection { get; set; } + + protected override DbParameterCollection DbParameterCollection { get; } = new SpannerParameterCollection(); + public new SpannerParameterCollection Parameters => (SpannerParameterCollection)DbParameterCollection; + + SpannerTransaction? _transaction; + protected override DbTransaction? DbTransaction + { + get => _transaction; + set + { + var tx = (SpannerTransaction?)value; + + if (tx is { IsCompleted: true }) + throw new InvalidOperationException("Transaction is already completed"); + _transaction = tx; + } + } + + public override bool DesignTimeVisible { get; set; } + + private bool HasTransaction => DbTransaction is SpannerTransaction; + private readonly Mutation? _mutation; + + public TransactionOptions.Types.ReadOnly? SingleUseReadOnlyTransactionOptions { get; set; } + public RequestOptions? RequestOptions { get; set; } + + private bool _disposed; + + public SpannerCommand() {} + + internal SpannerCommand(SpannerConnection connection) + { + Connection = GaxPreconditions.CheckNotNull(connection, nameof(connection)); + } + + public SpannerCommand(string commandText, SpannerConnection connection) + { + Connection = GaxPreconditions.CheckNotNull(connection, nameof(connection)); + _commandText = GaxPreconditions.CheckNotNull(commandText, nameof(commandText)); + } + + public SpannerCommand(string cmdText, SpannerConnection connection, SpannerTransaction? transaction) + : this(cmdText, connection) + => Transaction = transaction; + + internal SpannerCommand(SpannerConnection connection, Mutation mutation) + { + Connection = GaxPreconditions.CheckNotNull(connection, nameof(connection)); + _mutation = mutation; + } + + protected override void Dispose(bool disposing) + { + base.Dispose(disposing); + _disposed = true; + } + + private void CheckDisposed() + { + if (_disposed) + { + throw new ObjectDisposedException(nameof(SpannerCommand)); + } + } + + public override void Cancel() + { + // TODO: Implement in Spanner lib + } + + internal ExecuteSqlRequest BuildStatement(ExecuteSqlRequest.Types.QueryMode mode = ExecuteSqlRequest.Types.QueryMode.Normal) + { + GaxPreconditions.CheckState(!(HasTransaction && SingleUseReadOnlyTransactionOptions != null), + "Cannot set both a transaction and single-use read-only options"); + var spannerParams = ((SpannerParameterCollection)DbParameterCollection).CreateSpannerParams(); + var queryParams = spannerParams.Item1; + var paramTypes = spannerParams.Item2; + var sql = CommandText; + if (CommandType == CommandType.TableDirect) + { + // TODO: Quote the table name + sql = $"select * from {sql}"; + } + var statement = new ExecuteSqlRequest + { + Sql = sql, + Params = queryParams, + RequestOptions = RequestOptions, + QueryMode = mode, + }; + statement.ParamTypes.Add(paramTypes); + if (SingleUseReadOnlyTransactionOptions != null) + { + statement.Transaction = new TransactionSelector + { + SingleUse = new TransactionOptions + { + ReadOnly = SingleUseReadOnlyTransactionOptions, + }, + }; + } + + return statement; + } + + private Mutation BuildMutation() + { + GaxPreconditions.CheckNotNull(_mutation, nameof(_mutation)); + GaxPreconditions.CheckNotNull(SpannerConnection, nameof(SpannerConnection)); + GaxPreconditions.CheckState(!(HasTransaction && SingleUseReadOnlyTransactionOptions != null), + "Cannot set both a transaction and single-use read-only options"); + + var mutation = _mutation!.Clone(); + Mutation.Types.Write? write = null; + Mutation.Types.Delete? delete = mutation.OperationCase == Mutation.OperationOneofCase.Delete + ? mutation.Delete + : null; + switch (mutation.OperationCase) + { + case Mutation.OperationOneofCase.Insert: + write = mutation.Insert; + break; + case Mutation.OperationOneofCase.Update: + write = mutation.Update; + break; + case Mutation.OperationOneofCase.InsertOrUpdate: + write = mutation.InsertOrUpdate; + break; + case Mutation.OperationOneofCase.Replace: + write = mutation.Replace; + break; + } + + var values = new ListValue(); + for (var index = 0; index < DbParameterCollection.Count; index++) + { + var param = DbParameterCollection[index]; + if (param is SpannerParameter spannerParameter) + { + if (write != null) + { + var name = param.ParameterName; + if (name.StartsWith("@")) + { + name = name[1..]; + } + + write.Columns.Add(name); + } + + values.Values.Add(spannerParameter.ConvertToProto(spannerParameter)); + } + else + { + throw new ArgumentException("parameter is not a SpannerParameter: " + param.ParameterName); + } + } + + write?.Values.Add(values); + if (delete != null) + { + delete.KeySet = new KeySet(); + delete.KeySet.Keys.Add(values); + } + + return mutation; + } + + private void ExecuteMutation() + { + GaxPreconditions.CheckState(_mutation != null, "Cannot execute mutation"); + var mutations = new BatchWriteRequest.Types.MutationGroup + { + Mutations = { BuildMutation() } + }; + SpannerConnection.LibConnection!.WriteMutations(mutations); + } + + private Rows Execute(ExecuteSqlRequest.Types.QueryMode mode = ExecuteSqlRequest.Types.QueryMode.Normal) + { + CheckCommandStateForExecution(); + return TranslateException(() => SpannerConnection.LibConnection!.Execute(BuildStatement(mode))); + } + + private Task ExecuteAsync(CancellationToken cancellationToken) + { + return ExecuteAsync(ExecuteSqlRequest.Types.QueryMode.Normal, cancellationToken); + } + + private Task ExecuteAsync(ExecuteSqlRequest.Types.QueryMode mode, CancellationToken cancellationToken) + { + CheckCommandStateForExecution(); + if (cancellationToken.IsCancellationRequested) + { + return Task.FromCanceled(cancellationToken); + } + return TranslateException(() => SpannerConnection.LibConnection!.ExecuteAsync(BuildStatement(mode))); + } + + private void CheckCommandStateForExecution() + { + GaxPreconditions.CheckState(!string.IsNullOrEmpty(_commandText), "Cannot execute empty command"); + GaxPreconditions.CheckState(Connection != null, "No connection has been set for the command"); + GaxPreconditions.CheckState(Transaction == null || Transaction.Connection == SpannerConnection, + "The transaction that has been set for this command is from a different connection"); + } + + public override int ExecuteNonQuery() + { + CheckDisposed(); + if (_mutation != null) + { + ExecuteMutation(); + return 1; + } + + var rows = Execute(); + try + { + return (int)rows.UpdateCount; + } + finally + { + rows.Close(); + } + } + + public override object? ExecuteScalar() + { + CheckDisposed(); + GaxPreconditions.CheckState(_mutation == null, "Cannot execute mutations with ExecuteScalar()"); + var rows = Execute(); + using var reader = new SpannerDataReader(SpannerConnection, rows, CommandBehavior.Default); + if (reader.Read()) + { + if (reader.FieldCount > 0) + { + return reader.GetValue(0); + } + } + + return null; + } + + public override void Prepare() + { + CheckDisposed(); + Execute(ExecuteSqlRequest.Types.QueryMode.Plan); + } + + public override Task PrepareAsync(CancellationToken cancellationToken = default) + { + CheckDisposed(); + return ExecuteAsync(ExecuteSqlRequest.Types.QueryMode.Plan, cancellationToken); + } + + protected override DbParameter CreateDbParameter() + { + CheckDisposed(); + return new SpannerParameter(); + } + + protected override DbDataReader ExecuteDbDataReader(CommandBehavior behavior) + { + CheckDisposed(); + GaxPreconditions.CheckState(_mutation == null, "Cannot execute mutations with ExecuteDbDataReader()"); + try + { + var mode = behavior.HasFlag(CommandBehavior.SchemaOnly) + ? ExecuteSqlRequest.Types.QueryMode.Plan + : ExecuteSqlRequest.Types.QueryMode.Normal; + var rows = Execute(mode); + return new SpannerDataReader(SpannerConnection, rows, behavior); + } + catch (SpannerException exception) + { + if (behavior.HasFlag(CommandBehavior.CloseConnection)) + { + SpannerConnection.Close(); + } + throw new SpannerDbException(exception); + } + catch (Exception) + { + if (behavior.HasFlag(CommandBehavior.CloseConnection)) + { + SpannerConnection.Close(); + } + throw; + } + } + + protected override async Task ExecuteDbDataReaderAsync(CommandBehavior behavior, + CancellationToken cancellationToken) + { + CheckDisposed(); + GaxPreconditions.CheckState(_mutation == null, "Cannot execute mutations with ExecuteDbDataReader()"); + try + { + var mode = behavior.HasFlag(CommandBehavior.SchemaOnly) + ? ExecuteSqlRequest.Types.QueryMode.Plan + : ExecuteSqlRequest.Types.QueryMode.Normal; + var rows = await ExecuteAsync(mode, cancellationToken); + return new SpannerDataReader(SpannerConnection, rows, behavior); + } + catch (SpannerException exception) + { + if (behavior.HasFlag(CommandBehavior.CloseConnection)) + { + SpannerConnection.Close(); + } + throw new SpannerDbException(exception); + } + catch (Exception) + { + if (behavior.HasFlag(CommandBehavior.CloseConnection)) + { + await SpannerConnection.CloseAsync(); + } + throw; + } + } + + object ICloneable.Clone() => Clone(); + + public virtual SpannerCommand Clone() + { + var clone = new SpannerCommand() + { + Connection = Connection, + _commandText = _commandText, + _transaction = _transaction, + CommandTimeout = CommandTimeout, + CommandType = CommandType, + DesignTimeVisible = DesignTimeVisible, + }; + (DbParameterCollection as SpannerParameterCollection)?.CloneTo((clone.Parameters as SpannerParameterCollection)!); + return clone; + } + +} \ No newline at end of file diff --git a/drivers/spanner-ado-net/spanner-ado-net/SpannerCommandBuilder.cs b/drivers/spanner-ado-net/spanner-ado-net/SpannerCommandBuilder.cs new file mode 100644 index 00000000..67a0bd9b --- /dev/null +++ b/drivers/spanner-ado-net/spanner-ado-net/SpannerCommandBuilder.cs @@ -0,0 +1,50 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +using System.Data; +using System.Data.Common; + +namespace Google.Cloud.Spanner.DataProvider; + +/// +/// This class is currently not supported. +/// All methods in this class throw a NotImplementedException. +/// +public class SpannerCommandBuilder : DbCommandBuilder +{ + protected override void ApplyParameterInfo(DbParameter parameter, DataRow row, StatementType statementType, bool whereClause) + { + throw new System.NotImplementedException(); + } + + protected override string GetParameterName(int parameterOrdinal) + { + throw new System.NotImplementedException(); + } + + protected override string GetParameterName(string parameterName) + { + throw new System.NotImplementedException(); + } + + protected override string GetParameterPlaceholder(int parameterOrdinal) + { + throw new System.NotImplementedException(); + } + + protected override void SetRowUpdatingHandler(DbDataAdapter adapter) + { + throw new System.NotImplementedException(); + } +} diff --git a/drivers/spanner-ado-net/spanner-ado-net/SpannerConnection.cs b/drivers/spanner-ado-net/spanner-ado-net/SpannerConnection.cs new file mode 100644 index 00000000..d2f257a8 --- /dev/null +++ b/drivers/spanner-ado-net/spanner-ado-net/SpannerConnection.cs @@ -0,0 +1,365 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +using System; +using System.Collections.Generic; +using System.Data; +using System.Data.Common; +using System.Diagnostics.CodeAnalysis; +using System.Reflection; +using System.Threading; +using System.Threading.Tasks; +using Google.Api.Gax; +using Google.Cloud.Spanner.Common.V1; +using Google.Cloud.Spanner.V1; +using Google.Cloud.SpannerLib; + +namespace Google.Cloud.Spanner.DataProvider; + +public class SpannerConnection : DbConnection +{ + public bool UseSharedLibrary { get; set; } + + private string _connectionString = string.Empty; + + private SpannerConnectionStringBuilder? _connectionStringBuilder; + + [AllowNull] + public sealed override string ConnectionString { + get => _connectionString; + set + { + AssertClosed(); + if (string.IsNullOrWhiteSpace(value)) + { + _connectionStringBuilder = null; + _connectionString = string.Empty; + } + else + { + var builder = new SpannerConnectionStringBuilder(value); + builder.CheckValid(); + _connectionStringBuilder = builder; + _connectionString = value; + } + } + } + + public override string Database + { + get + { + if (string.IsNullOrWhiteSpace(ConnectionString) || _connectionStringBuilder == null) + { + return ""; + } + if (!string.IsNullOrEmpty(_connectionStringBuilder.DataSource)) + { + return _connectionStringBuilder.DataSource; + } + if (!string.IsNullOrEmpty(_connectionStringBuilder.Project) && + !string.IsNullOrEmpty(_connectionStringBuilder.Instance) && + !string.IsNullOrEmpty(_connectionStringBuilder.Project)) + { + return $"projects/{_connectionStringBuilder.Project}/instances/{_connectionStringBuilder.Instance}/databases/{_connectionStringBuilder.Database}"; + } + return ""; + } + } + + private ConnectionState InternalState + { + get => _state; + set + { + var originalState = _state; + _state = value; + OnStateChange(new StateChangeEventArgs(originalState, _state)); + } + } + + public override ConnectionState State => InternalState; + protected override DbProviderFactory DbProviderFactory => SpannerFactory.Instance; + + public override string DataSource => _connectionStringBuilder?.DataSource ?? string.Empty; + + public override string ServerVersion + { + get + { + AssertOpen(); + // TODO: Return an actual version number + return "1.0.0"; + } + } + + internal Version ServerVersionNormalized => Version.Parse(ServerVersion); + + internal string ServerVersionNormalizedString => FormattableString.Invariant($"{ServerVersionNormalized.Major:000}.{ServerVersionNormalized.Minor:000}.{ServerVersionNormalized.Build:0000}"); + + public override bool CanCreateBatch => true; + + private bool _disposed; + private ConnectionState _state = ConnectionState.Closed; + private SpannerPool? Pool { get; set; } + + private Connection? _libConnection; + + internal Connection? LibConnection + { + get + { + AssertOpen(); + return _libConnection; + } + } + + internal uint DefaultCommandTimeout => _connectionStringBuilder?.CommandTimeout ?? 0; + + private SpannerTransaction? _transaction; + + private SpannerSchemaProvider? _mSchemaProvider; + + private SpannerSchemaProvider GetSchemaProvider() => _mSchemaProvider ??= new SpannerSchemaProvider(this); + + public SpannerConnection() + { + } + + public SpannerConnection(string? connectionString) + { + ConnectionString = connectionString; + } + + public SpannerConnection(SpannerConnectionStringBuilder connectionStringBuilder) + { + GaxPreconditions.CheckNotNull(connectionStringBuilder, nameof(connectionStringBuilder)); + connectionStringBuilder.CheckValid(); + _connectionStringBuilder = connectionStringBuilder; + _connectionString = connectionStringBuilder.ConnectionString; + } + + protected override DbTransaction BeginDbTransaction(IsolationLevel isolationLevel) + { + return BeginTransaction(new TransactionOptions + { + IsolationLevel = SpannerTransaction.TranslateIsolationLevel(isolationLevel), + }); + } + + public DbTransaction BeginReadOnlyTransaction() + { + return BeginTransaction(new TransactionOptions + { + ReadOnly = new TransactionOptions.Types.ReadOnly(), + }); + } + + /// + /// Start a new transaction using the given TransactionOptions. + /// + /// The options to use for the new transaction + /// The new transaction + /// If the connection has an active transaction + public DbTransaction BeginTransaction(TransactionOptions transactionOptions) + { + EnsureOpen(); + GaxPreconditions.CheckState(!HasTransaction, "This connection has a transaction."); + _transaction = new SpannerTransaction(this, transactionOptions); + return _transaction; + } + + internal void ClearTransaction() + { + _transaction = null; + } + + internal bool HasTransaction => _transaction != null; + + public override void ChangeDatabase(string databaseName) + { + GaxPreconditions.CheckNotNullOrEmpty(databaseName, nameof(databaseName)); + GaxPreconditions.CheckState(!HasTransaction, "Cannot change database when a transaction is open"); + if (_connectionStringBuilder == null) + { + ConnectionString = $"Data Source={databaseName}"; + return; + } + if (DatabaseName.TryParse(databaseName, allowUnparsed: false, out _)) + { + _connectionStringBuilder.DataSource = databaseName; + } + else + { + if (DatabaseName.TryParse(_connectionStringBuilder.DataSource, out var currentDatabase)) + { + _connectionStringBuilder.DataSource = $"projects/{currentDatabase.ProjectId}/instances/{currentDatabase.InstanceId}/databases/{databaseName}"; + } + else if (!string.IsNullOrEmpty(_connectionStringBuilder.Project) && !string.IsNullOrEmpty(_connectionStringBuilder.Instance)) + { + _connectionStringBuilder.Database = databaseName; + } + else + { + throw new ArgumentException($"Invalid database name: {databaseName}"); + } + } + if (_state == ConnectionState.Open) + { + Close(); + Open(); + } + } + + public override void Close() + { + if (InternalState == ConnectionState.Closed) + { + return; + } + + InternalState = ConnectionState.Closed; + _libConnection?.Close(); + _libConnection = null; + } + + protected override void Dispose(bool disposing) + { + if (_disposed) + { + return; + } + if (disposing) + { + Close(); + } + base.Dispose(disposing); + _disposed = true; + } + + public override void Open() + { + AssertClosed(); + if (ConnectionString == string.Empty || _connectionStringBuilder == null) + { + throw new InvalidOperationException("Connection string is empty"); + } + + try + { + InternalState = ConnectionState.Connecting; + Pool = SpannerPool.GetOrCreate(_connectionStringBuilder.SpannerLibConnectionString); + _libConnection = Pool.CreateConnection(); + InternalState = ConnectionState.Open; + } + catch (Exception) + { + InternalState = ConnectionState.Closed; + throw; + } + } + + private void EnsureOpen() + { + if (InternalState == ConnectionState.Closed) + { + Open(); + } + } + + private void AssertOpen() + { + if (InternalState != ConnectionState.Open) + { + throw new InvalidOperationException("Connection is not open"); + } + } + + private void AssertClosed() + { + if (InternalState != ConnectionState.Closed) + { + throw new InvalidOperationException("Connection is not closed"); + } + } + + public CommitResponse? WriteMutations(BatchWriteRequest.Types.MutationGroup mutations) + { + EnsureOpen(); + return LibConnection!.WriteMutations(mutations); + } + + public Task WriteMutationsAsync(BatchWriteRequest.Types.MutationGroup mutations, CancellationToken cancellationToken = default) + { + EnsureOpen(); + return LibConnection!.WriteMutationsAsync(mutations, cancellationToken); + } + + public new SpannerCommand CreateCommand() => (SpannerCommand) base.CreateCommand(); + + protected override DbCommand CreateDbCommand() + { + var cmd = new SpannerCommand(this); + return cmd; + } + + protected override DbBatch CreateDbBatch() + { + return new SpannerBatch(this); + } + + public long[] ExecuteBatchDml(List commands) + { + EnsureOpen(); + var statements = new List(commands.Count); + foreach (var command in commands) + { + if (command is SpannerCommand spannerCommand) + { + var statement = spannerCommand.BuildStatement(); + var batchStatement = new ExecuteBatchDmlRequest.Types.Statement + { + Sql = statement.Sql, + Params = statement.Params, + }; + batchStatement.ParamTypes.Add(statement.ParamTypes); + statements.Add(batchStatement); + } + } + return LibConnection!.ExecuteBatch(statements); + } + + public DbCommand CreateInsertCommand(string table) + { + return new SpannerCommand(this, new Mutation { Insert = new Mutation.Types.Write { Table = table } }); + } + + public DbCommand CreateUpdateCommand(string table) + { + return new SpannerCommand(this, new Mutation { Update = new Mutation.Types.Write { Table = table } }); + } + + public DbCommand CreateDeleteCommand(string table) + { + return new SpannerCommand(this, new Mutation { Delete = new Mutation.Types.Delete { Table = table } }); + } + + public override DataTable GetSchema() => GetSchemaProvider().GetSchema(); + + public override DataTable GetSchema(string collectionName) + => GetSchema(collectionName, null); + + public override DataTable GetSchema(string collectionName, string?[]? restrictionValues) + => GetSchemaProvider().GetSchema(collectionName, restrictionValues); +} \ No newline at end of file diff --git a/drivers/spanner-ado-net/spanner-ado-net/SpannerConnectionStringBuilder.cs b/drivers/spanner-ado-net/spanner-ado-net/SpannerConnectionStringBuilder.cs new file mode 100644 index 00000000..b00bca5c --- /dev/null +++ b/drivers/spanner-ado-net/spanner-ado-net/SpannerConnectionStringBuilder.cs @@ -0,0 +1,573 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +using System; +using System.Collections; +using System.Collections.Frozen; +using System.Collections.Generic; +using System.ComponentModel; +using System.Data; +using System.Data.Common; +using System.Diagnostics.CodeAnalysis; +using System.Globalization; +using System.Linq; +using System.Text; + +namespace Google.Cloud.Spanner.DataProvider; + +public class SpannerConnectionStringBuilder : DbConnectionStringBuilder +{ + /// + /// The fully qualified name of the Spanner database to connect to. + /// Example: projects/my-project/instances/my-instance/databases/my-database + /// + [Category("Connection")] + [Description("The fully qualified name of the database to use. This property takes precedence over any Project, Instance, or Database that has been set in the connection string.")] + [DisplayName("Data Source")] + public string DataSource + { + get => SpannerConnectionStringOption.DataSource.GetValue(this); + set => SpannerConnectionStringOption.DataSource.SetValue(this, value); + } + + /// + /// The name of the Spanner instance to connect to. + /// + [Category("Connection")] + [Description("The name of the Google Cloud project to use.")] + [DisplayName("Project")] + public string Project + { + get => SpannerConnectionStringOption.Project.GetValue(this); + set => SpannerConnectionStringOption.Project.SetValue(this, value); + } + + /// + /// The name of the Spanner instance to connect to. + /// + [Category("Connection")] + [Description("The name of the Spanner instance to use.")] + [DisplayName("Instance")] + public string Instance + { + get => SpannerConnectionStringOption.Instance.GetValue(this); + set => SpannerConnectionStringOption.Instance.SetValue(this, value); + } + + /// + /// The name of the Spanner database to connect to. + /// + [Category("Connection")] + [Description("The name of the database to use")] + [DisplayName("Database")] + public string Database + { + get => SpannerConnectionStringOption.Database.GetValue(this); + set => SpannerConnectionStringOption.Database.SetValue(this, value); + } + + /// + /// The hostname or IP address of the Spanner server to connect to. + /// + [Category("Connection")] + [Description("The hostname or IP address of the Spanner server to connect to.")] + [DefaultValue("")] + [DisplayName("Host")] + public string Host + { + get => SpannerConnectionStringOption.Host.GetValue(this); + set => SpannerConnectionStringOption.Host.SetValue(this, value); + } + + /// + /// The TCP port of the Spanner server to connect to. + /// + [Category("Connection")] + [DefaultValue(443u)] + [Description("The TCP port of the Spanner server to connect to.")] + [DisplayName("Port")] + public uint Port + { + get => SpannerConnectionStringOption.Port.GetValue(this); + set => SpannerConnectionStringOption.Port.SetValue(this, value); + } + + /// + /// Whether to use plain text communication with the server. The default is SSL. + /// + [Category("Connection")] + [DefaultValue(false)] + [Description("Whether to use plain text or SSL (default).")] + [DisplayName("UsePlainText")] + public bool UsePlainText + { + get => SpannerConnectionStringOption.UsePlainText.GetValue(this); + set => SpannerConnectionStringOption.UsePlainText.SetValue(this, value); + } + + /// + /// The time in milliseconds to wait for a connection before terminating the attempt and generating an error. + /// The default value is 15000 (15 seconds). + /// + [Category("Timeout")] + [Description("The time in milliseconds to wait for a connection before terminating the attempt and generating an error.")] + [DefaultValue(15000u)] + [DisplayName("Connection Timeout")] + public uint ConnectionTimeout + { + get => SpannerConnectionStringOption.ConnectionTimeout.GetValue(this); + set => SpannerConnectionStringOption.ConnectionTimeout.SetValue(this, value); + } + + /// + /// The time in milliseconds to wait for a command before terminating the attempt and generating an error. + /// The default value is 0, which means that the command should use the default timeout set by Spanner. + /// + [Category("Timeout")] + [Description("The time in milliseconds to wait for a command before terminating the attempt and generating an error.")] + [DefaultValue(0u)] + [DisplayName("Command Timeout")] + public uint CommandTimeout + { + get => SpannerConnectionStringOption.CommandTimeout.GetValue(this); + set => SpannerConnectionStringOption.CommandTimeout.SetValue(this, value); + } + + /// + /// The maximum time in milliseconds that a read/write transaction may take to execute. + /// The default value is 0, which means that there is no transaction timeout. + /// + [Category("Timeout")] + [Description("The maximum time in milliseconds that a read/write transaction may take to execute.")] + [DefaultValue(0u)] + [DisplayName("Transaction Timeout")] + public uint TransactionTimeout + { + get => SpannerConnectionStringOption.TransactionTimeout.GetValue(this); + set => SpannerConnectionStringOption.TransactionTimeout.SetValue(this, value); + } + + /// + /// The default isolation level that should be used for transactions on connections created from this connection + /// string. + /// + [Category("Transaction")] + [Description("The default isolation level to use for transactions on this connection.")] + [DefaultValue(IsolationLevel.Unspecified)] + [DisplayName("DefaultIsolationLevel")] + public IsolationLevel DefaultIsolationLevel + { + get => SpannerConnectionStringOption.DefaultIsolationLevel.GetValue(this); + set => SpannerConnectionStringOption.DefaultIsolationLevel.SetValue(this, value); + } + + /// + /// The search_path that should be used by the connection. + /// + [Category("Options")] + [Description("The search path for this connection.")] + [DefaultValue("")] + [DisplayName("SearchPath")] + public string SearchPath + { + get => SpannerConnectionStringOption.SearchPath.GetValue(this); + set => SpannerConnectionStringOption.SearchPath.SetValue(this, value); + } + + /// + /// Any other options that should be set for the connection in the format key1=value1;key2=value2;... + /// + [Category("Options")] + [Description("Any additional options to set for the connection.")] + [DefaultValue("")] + [DisplayName("Options")] + public string Options + { + get => SpannerConnectionStringOption.Options.GetValue(this); + set => SpannerConnectionStringOption.Options.SetValue(this, value); + } + + /// + /// Returns an that contains the keys in the . + /// + public override ICollection Keys => base.Keys.Cast().OrderBy(static x => SpannerConnectionStringOption.OptionNames.IndexOf(x)).ToList(); + + /// + /// Whether this contains a set option with the specified name. + /// + /// The option name. + /// true if an option with that name is set; otherwise, false. + public override bool ContainsKey(string keyword) => + SpannerConnectionStringOption.TryGetOptionForKey(keyword) is { } option && base.ContainsKey(option.Key); + + /// + /// Removes the option with the specified name. + /// + /// The option name. + public override bool Remove(string keyword) => + SpannerConnectionStringOption.TryGetOptionForKey(keyword) is { } option && base.Remove(option.Key); + + /// + /// Retrieves an option value by name. + /// + /// The option name. + /// That option's value, if set. + [AllowNull] + public override object this[string key] + { + get + { + var option = SpannerConnectionStringOption.TryGetOptionForKey(key); + return option == null ? base[key] : option.GetObject(this); + } + set + { + var option = SpannerConnectionStringOption.TryGetOptionForKey(key); + if (option == null) + { + base[key] = value; + } + else + { + if (value is null) + { + base[option.Key] = null; + } + else + { + option.SetObject(this, value); + } + } + } + } + + public SpannerConnectionStringBuilder() + { + } + + public SpannerConnectionStringBuilder(string connectionString) + { + ConnectionString = connectionString; + } + + internal void DoSetValue(string key, object? value) => base[key] = value; + + internal SpannerConnectionStringBuilder Clone() => new(ConnectionString); + + internal void CheckValid() + { + if (string.IsNullOrEmpty(ConnectionString)) + { + throw new ArgumentException("Empty connection string"); + } + if (string.IsNullOrEmpty(DataSource)) + { + if (string.IsNullOrEmpty(Project) || string.IsNullOrEmpty(Instance) || string.IsNullOrEmpty(Database)) + { + throw new ArgumentException("The connection string must either contain a Data Source or a Project, Instance, and Database name"); + } + } + } + + internal string SpannerLibConnectionString + { + get + { + CheckValid(); + var builder = new StringBuilder(); + if (Host != "") + { + builder.Append(Host); + if (Port != 443) + { + builder.Append(":"); + builder.Append(Port); + } + builder.Append('/'); + } + if (DataSource != "") + { + builder.Append(DataSource); + } + else if (Project != "" && Instance != "" && Database != "") + { + builder.Append("projects/").Append(Project); + builder.Append("/instances/").Append(Instance); + builder.Append("/databases/").Append(Database); + } + else + { + throw new ArgumentException("Invalid connection string. Either Data Source or Project, Instance, and Database must be specified."); + } + foreach (var key in Keys.Cast()) + { + if (SpannerConnectionStringOption.SOptions.ContainsKey(key)) + { + var option = SpannerConnectionStringOption.SOptions[key]; + if (option.SpannerLibKey != "") + { + builder.Append(';').Append(option.SpannerLibKey).Append('=').Append(this[key]); + } + else if (key == "Options") + { + builder.Append(';').Append(this[key]); + } + } + else + { + builder.Append(';').Append(key).Append('=').Append(this[key]); + } + } + return builder.ToString(); + } + } + +} + +internal abstract class SpannerConnectionStringOption +{ + public static List OptionNames { get; } = []; + + // Connection Options + public static readonly SpannerConnectionStringReferenceOption DataSource; + public static readonly SpannerConnectionStringReferenceOption Host; + public static readonly SpannerConnectionStringValueOption Port; + public static readonly SpannerConnectionStringReferenceOption Project; + public static readonly SpannerConnectionStringReferenceOption Instance; + public static readonly SpannerConnectionStringReferenceOption Database; + + // Timeout Options + public static readonly SpannerConnectionStringValueOption ConnectionTimeout; + public static readonly SpannerConnectionStringValueOption CommandTimeout; + public static readonly SpannerConnectionStringValueOption TransactionTimeout; + + // SSL/TLS Options + public static readonly SpannerConnectionStringValueOption UsePlainText; + + // Transaction Options + public static readonly SpannerConnectionStringValueOption DefaultIsolationLevel; + + // Other options + public static readonly SpannerConnectionStringReferenceOption SearchPath; + public static readonly SpannerConnectionStringReferenceOption Options; + + public static SpannerConnectionStringOption? TryGetOptionForKey(string key) => SOptions.GetValueOrDefault(key); + + public static SpannerConnectionStringOption GetOptionForKey(string key) => + TryGetOptionForKey(key) ?? throw new ArgumentException($"Option '{key}' not supported."); + + public string Key => _keys[0]; + public IReadOnlyList Keys => _keys; + + internal string SpannerLibKey { get; } + + public abstract object GetObject(SpannerConnectionStringBuilder builder); + public abstract void SetObject(SpannerConnectionStringBuilder builder, object value); + + protected SpannerConnectionStringOption(IReadOnlyList keys) : this(keys, keys[0]) + { + } + + protected SpannerConnectionStringOption(IReadOnlyList keys, string spannerLibKey) + { + _keys = keys; + SpannerLibKey = spannerLibKey; + } + + private static void AddOption(Dictionary options, SpannerConnectionStringOption option) + { + foreach (var key in option._keys) + { + options.Add(key, option); + } + OptionNames.Add(option._keys[0]); + } + + static SpannerConnectionStringOption() + { + var options = new Dictionary(StringComparer.OrdinalIgnoreCase); + + // Base Options + AddOption(options, DataSource = new( + keys: ["Data Source", "DataSource"], + spannerLibKey: "", + defaultValue: "")); + + AddOption(options, Host = new( + keys: ["Host", "Server"], + spannerLibKey: "", + defaultValue: "")); + + AddOption(options, Port = new( + keys: ["Port"], + spannerLibKey: "", + defaultValue: 443u)); + + AddOption(options, Project = new( + keys: ["Project"], + spannerLibKey: "", + defaultValue: "")); + + AddOption(options, Instance = new( + keys: ["Instance"], + spannerLibKey: "", + defaultValue: "")); + + AddOption(options, Database = new( + keys: ["Database", "Initial Catalog"], + spannerLibKey: "", + defaultValue: "")); + + // Timeout Options + AddOption(options, ConnectionTimeout = new( + keys: ["Connection Timeout", "ConnectionTimeout", "Connect Timeout", "connect_timeout"], + spannerLibKey: "connect_timeout", + defaultValue: 15000u)); + + AddOption(options, CommandTimeout = new( + keys: ["Command Timeout", "CommandTimeout", "command_timeout", "statement_timeout"], + spannerLibKey: "statement_timeout", + defaultValue: 0u)); + + AddOption(options, TransactionTimeout = new( + keys: ["Transaction Timeout", "TransactionTimeout", "transaction_timeout"], + spannerLibKey: "transaction_timeout", + defaultValue: 0u)); + + // SSL/TLS Options + AddOption(options, UsePlainText = new( + keys: ["UsePlainText", "Use plain text", "Plain text", "use_plain_text"], + defaultValue: false)); + + // Transaction Options + AddOption(options, DefaultIsolationLevel = new( + keys: ["DefaultIsolationLevel", "default_isolation_level"], + defaultValue: IsolationLevel.Unspecified)); + + // Other options + AddOption(options, SearchPath = new( + keys: ["SearchPath", "search_path"], + spannerLibKey: "search_path", + defaultValue: "")); + + // Other options + AddOption(options, Options = new( + keys: ["Options"], + spannerLibKey: "", + defaultValue: "")); + + SOptions = options.ToFrozenDictionary(StringComparer.OrdinalIgnoreCase); + } + + internal static readonly FrozenDictionary SOptions; + + private readonly IReadOnlyList _keys; +} + +internal sealed class SpannerConnectionStringValueOption : SpannerConnectionStringOption + where T : struct +{ + public SpannerConnectionStringValueOption(IReadOnlyList keys, T defaultValue, Func? coerce = null) + : this(keys, keys[0], defaultValue, coerce) + { + } + + public SpannerConnectionStringValueOption(IReadOnlyList keys, string spannerLibKey, T defaultValue, Func? coerce = null) + : base(keys, spannerLibKey) + { + DefaultValue = defaultValue; + _coerce = coerce; + } + + public T DefaultValue { get; } + + public T GetValue(SpannerConnectionStringBuilder builder) => + builder.TryGetValue(Key, out var objectValue) ? ChangeType(objectValue) : DefaultValue; + + public void SetValue(SpannerConnectionStringBuilder builder, T value) => + builder.DoSetValue(Key, _coerce is null ? value : _coerce(value)); + + public override object GetObject(SpannerConnectionStringBuilder builder) => GetValue(builder); + + public override void SetObject(SpannerConnectionStringBuilder builder, object value) => SetValue(builder, ChangeType(value)); + + private T ChangeType(object objectValue) + { + if (typeof(T) == typeof(bool) && objectValue is string booleanString) + { + if (string.Equals(booleanString, "yes", StringComparison.OrdinalIgnoreCase)) + { + return (T)(object)true; + } + if (string.Equals(booleanString, "on", StringComparison.OrdinalIgnoreCase)) + { + return (T)(object)true; + } + if (string.Equals(booleanString, "no", StringComparison.OrdinalIgnoreCase)) + { + return (T)(object)false; + } + if (string.Equals(booleanString, "off", StringComparison.OrdinalIgnoreCase)) + { + return (T)(object)false; + } + } + + if (typeof(T).IsEnum && objectValue is string enumString) + { + enumString = enumString.Trim().Replace("_", "").Replace(" ", ""); + return (T)Enum.Parse(typeof(T), enumString, ignoreCase: true); + } + + try + { + return (T) Convert.ChangeType(objectValue, typeof(T), CultureInfo.InvariantCulture); + } + catch (Exception ex) + { + var exceptionMessage = string.Create(CultureInfo.InvariantCulture, $"Invalid value '{objectValue}' for '{Key}' connection string option."); + throw new ArgumentException(exceptionMessage, ex); + } + } + + private readonly Func? _coerce; +} + +internal sealed class SpannerConnectionStringReferenceOption : SpannerConnectionStringOption + where T : class +{ + public SpannerConnectionStringReferenceOption(IReadOnlyList keys, string spannerLibKey, T defaultValue, Func? coerce = null) + : base(keys, spannerLibKey) + { + DefaultValue = defaultValue; + _coerce = coerce; + } + + public T DefaultValue { get; } + + public T GetValue(SpannerConnectionStringBuilder builder) => + builder.TryGetValue(Key, out var objectValue) ? ChangeType(objectValue) : DefaultValue; + + public void SetValue(SpannerConnectionStringBuilder builder, T? value) => + builder.DoSetValue(Key, _coerce is null ? value : _coerce(value)); + + public override object GetObject(SpannerConnectionStringBuilder builder) => GetValue(builder); + + public override void SetObject(SpannerConnectionStringBuilder builder, object value) => SetValue(builder, ChangeType(value)); + + private static T ChangeType(object objectValue) => + (T) Convert.ChangeType(objectValue, typeof(T), CultureInfo.InvariantCulture); + + private readonly Func? _coerce; +} diff --git a/drivers/spanner-ado-net/spanner-ado-net/SpannerDataAdapter.cs b/drivers/spanner-ado-net/spanner-ado-net/SpannerDataAdapter.cs new file mode 100644 index 00000000..f69e85fc --- /dev/null +++ b/drivers/spanner-ado-net/spanner-ado-net/SpannerDataAdapter.cs @@ -0,0 +1,26 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +using System.Data.Common; + +namespace Google.Cloud.Spanner.DataProvider; + +// SpannerDataAdapter does not do anything, as Spanner does not return base table and key column information for simple +// select statements. +// +// One possible way to implement it could be to only support it in combination with a TableDirect command. +public class SpannerDataAdapter : DbDataAdapter +{ + +} \ No newline at end of file diff --git a/drivers/spanner-ado-net/spanner-ado-net/SpannerDataReader.cs b/drivers/spanner-ado-net/spanner-ado-net/SpannerDataReader.cs new file mode 100644 index 00000000..902a9f5a --- /dev/null +++ b/drivers/spanner-ado-net/spanner-ado-net/SpannerDataReader.cs @@ -0,0 +1,883 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +using System; +using System.Collections; +using System.Collections.Generic; +using System.Data; +using System.Data.Common; +using System.Globalization; +using System.IO; +using System.Text; +using System.Threading; +using System.Threading.Tasks; +using System.Xml; +using Google.Api.Gax; +using Google.Cloud.Spanner.V1; +using Google.Cloud.SpannerLib; +using Google.Protobuf.WellKnownTypes; +using TypeCode = Google.Cloud.Spanner.V1.TypeCode; + +namespace Google.Cloud.Spanner.DataProvider; + +public class SpannerDataReader : DbDataReader +{ + private readonly SpannerConnection _connection; + private readonly CommandBehavior _commandBehavior; + private bool IsSingleRow => _commandBehavior.HasFlag(CommandBehavior.SingleRow); + private Rows LibRows { get; } + private bool _closed; + private bool _hasReadData; + private bool _hasData; + + public override int FieldCount + { + get + { + CheckNotClosed(); + return LibRows.Metadata?.RowType.Fields.Count ?? 0; + } + } + + public override object this[int ordinal] => GetFieldValue(ordinal); + public override object this[string name] => this[GetOrdinal(name)]; + + public override int RecordsAffected + { + get + { + CheckNotClosed(); + return (int)LibRows.UpdateCount; + } + } + + public override bool HasRows + { + get + { + CheckNotClosed(); + if (LibRows.Metadata?.RowType.Fields.Count == 0) + { + return false; + } + if (_hasReadData) + { + return _hasData; + } + return CheckForRows(); + } + } + public override bool IsClosed => _closed; + public override int Depth => 0; + + private ListValue? _currentRow; + private ListValue? _tempRow; + + internal SpannerDataReader(SpannerConnection connection, Rows libRows, CommandBehavior behavior) + { + _connection = connection; + LibRows = libRows; + _commandBehavior = behavior; + } + + private void CheckNotClosed() + { + GaxPreconditions.CheckState(!_closed, "Reader has been closed"); + } + + public override void Close() + { + if (_closed) + { + return; + } + + _closed = true; + LibRows.Close(); + if (_commandBehavior.HasFlag(CommandBehavior.CloseConnection)) + { + _connection.Close(); + } + } + + public override bool Read() + { + if (!InternalRead()) + { + _hasReadData = true; + _currentRow = LibRows.Next(); + } + _hasData = _hasData || _currentRow != null; + return _currentRow != null; + } + + public override async Task ReadAsync(CancellationToken cancellationToken) + { + try + { + if (!InternalRead()) + { + _hasReadData = true; + _currentRow = await LibRows.NextAsync(cancellationToken); + } + _hasData = _hasData || _currentRow != null; + return _currentRow != null; + } + catch (SpannerException exception) + { + throw SpannerDbException.TranslateException(exception); + } + } + + private bool InternalRead() + { + CheckNotClosed(); + if (_tempRow != null) + { + _currentRow = _tempRow; + _tempRow = null; + _hasReadData = true; + return true; + } + if (IsSingleRow && _hasReadData) + { + _currentRow = null; + return true; + } + return false; + } + + private bool CheckForRows() + { + _tempRow ??= LibRows.Next(); + return _tempRow != null; + } + + public override DataTable? GetSchemaTable() + { + CheckNotClosed(); + var metadata = LibRows.Metadata; + if (metadata?.RowType == null || metadata.RowType.Fields.Count == 0) + { + return null; + } + var table = new DataTable("SchemaTable"); + + table.Columns.Add("ColumnName", typeof(string)); + table.Columns.Add("ColumnOrdinal", typeof(int)); + table.Columns.Add("ColumnSize", typeof(int)); + table.Columns.Add("NumericPrecision", typeof(int)); + table.Columns.Add("NumericScale", typeof(int)); + table.Columns.Add("IsUnique", typeof(bool)); + table.Columns.Add("IsKey", typeof(bool)); + table.Columns.Add("BaseServerName", typeof(string)); + table.Columns.Add("BaseCatalogName", typeof(string)); + table.Columns.Add("BaseColumnName", typeof(string)); + table.Columns.Add("BaseSchemaName", typeof(string)); + table.Columns.Add("BaseTableName", typeof(string)); + table.Columns.Add("DataType", typeof(System.Type)); + table.Columns.Add("AllowDBNull", typeof(bool)); + table.Columns.Add("ProviderType", typeof(int)); + table.Columns.Add("IsAliased", typeof(bool)); + table.Columns.Add("IsExpression", typeof(bool)); + table.Columns.Add("IsIdentity", typeof(bool)); + table.Columns.Add("IsAutoIncrement", typeof(bool)); + table.Columns.Add("IsRowVersion", typeof(bool)); + table.Columns.Add("IsHidden", typeof(bool)); + table.Columns.Add("IsLong", typeof(bool)); + table.Columns.Add("IsReadOnly", typeof(bool)); + table.Columns.Add("ProviderSpecificDataType", typeof(System.Type)); + table.Columns.Add("DataTypeName", typeof(string)); + + var ordinal = 0; + foreach (var column in metadata.RowType.Fields) + { + ordinal++; + var row = table.NewRow(); + row["ColumnName"] = column.Name; + row["ColumnOrdinal"] = ordinal; + row["ColumnSize"] = -1; + row["NumericPrecision"] = 0; + row["NumericScale"] = 0; + row["IsUnique"] = false; + row["IsKey"] = false; + row["BaseServerName"] = ""; + row["BaseCatalogName"] = ""; + row["BaseColumnName"] = ""; + row["BaseSchemaName"] = ""; + row["BaseTableName"] = ""; + row["DataType"] = TypeConversion.GetSystemType(column.Type); + row["AllowDBNull"] = true; + row["ProviderType"] = (int)column.Type.Code; + row["IsAliased"] = false; + row["IsExpression"] = false; + row["IsIdentity"] = false; + row["IsAutoIncrement"] = false; + row["IsRowVersion"] = false; + row["IsHidden"] = false; + row["IsLong"] = false; + row["IsReadOnly"] = false; + row["DataTypeName"] = column.Type.Code.ToString(); + + table.Rows.Add(row); + } + return table; + } + + public override string GetString(int ordinal) + { + var value = GetProtoValue(ordinal); + CheckNotNull(ordinal); + if (value.HasStringValue) + { + return value.StringValue; + } + if (value.HasNumberValue) + { + var type = GetSpannerType(ordinal); + if (type.Code == TypeCode.Float32) + { + return ((float) value.NumberValue).ToString(CultureInfo.InvariantCulture); + } + return value.NumberValue.ToString(CultureInfo.InvariantCulture); + } + if (value.HasBoolValue) + { + return value.BoolValue.ToString(); + } + throw new InvalidCastException("not a valid string value"); + } + + public override bool GetBoolean(int ordinal) + { + var value = GetProtoValue(ordinal); + CheckNotNull(ordinal); + if (value.HasStringValue) + { + try + { + return bool.Parse(value.StringValue); + } + catch (Exception exception) + { + throw new InvalidCastException(exception.Message, exception); + } + } + if (value.HasBoolValue) + { + return value.BoolValue; + } + throw new InvalidCastException("not a valid bool value"); + } + + public override byte GetByte(int ordinal) + { + CheckValidPosition(); + CheckNotNull(ordinal); + throw new InvalidCastException("not a valid byte value"); + } + + public override long GetBytes(int ordinal, long dataOffset, byte[]? buffer, int bufferOffset, int length) + { + CheckValidPosition(); + CheckValidOrdinal(ordinal); + CheckNotNull(ordinal); + var code = LibRows.Metadata!.RowType.Fields[ordinal].Type.Code; + GaxPreconditions.CheckState(Array.Exists([TypeCode.Bytes, TypeCode.Json, TypeCode.String], c => c == code), + "Spanner only supports conversion to byte arrays for columns of type BYTES or STRING."); + Preconditions.CheckIndexRange(bufferOffset, nameof(bufferOffset), 0, buffer?.Length ?? 0); + Preconditions.CheckIndexRange(length, nameof(length), 0, buffer?.Length ?? int.MaxValue); + if (buffer != null) + { + Preconditions.CheckIndexRange(bufferOffset + length, nameof(length), 0, buffer.Length); + } + + byte[] bytes; + if (code == TypeCode.Bytes) + { + bytes = GetFieldValue(ordinal); + } + else + { + var s = GetFieldValue(ordinal); + bytes = Encoding.UTF8.GetBytes(s); + } + if (buffer == null) + { + // Return the length of the value if `buffer` is null: + // https://docs.microsoft.com/en-us/dotnet/api/system.data.idatarecord.getbytes?view=netstandard-2.1#remarks + return bytes.Length; + } + + var copyLength = Math.Min(length, bytes.Length - (int)dataOffset); + if (copyLength < 0) + { + // Read nothing and just return. + return 0; + } + + Array.Copy(bytes, (int)dataOffset, buffer, bufferOffset, copyLength); + return copyLength; + } + + public override char GetChar(int ordinal) + { + var value = GetProtoValue(ordinal); + CheckNotNull(ordinal); + var type = GetSpannerType(ordinal); + if (type.Code != TypeCode.String) + { + throw new InvalidCastException("not a valid char value"); + } + if (value.HasStringValue) + { + if (value.StringValue.Length == 0) + { + throw new InvalidCastException("not a valid char value"); + } + return value.StringValue[0]; + } + throw new InvalidCastException("not a valid char value"); + } + + public override long GetChars(int ordinal, long dataOffset, char[]? buffer, int bufferOffset, int length) + { + var value = GetProtoValue(ordinal); + var code = GetSpannerType(ordinal).Code; + if (!Array.Exists([TypeCode.Bytes, TypeCode.Json, TypeCode.String], c => c == code)) + { + throw new InvalidCastException("not a valid type for getting as chars"); + } + if (value.HasNullValue) + { + return 0; + } + if (buffer == null) + { + // Return the length of the value if `buffer` is null: + // https://docs.microsoft.com/en-us/dotnet/api/system.data.idatarecord.getbytes?view=netstandard-2.1#remarks + return value.StringValue.ToCharArray().Length; + } + Preconditions.CheckIndexRange(bufferOffset, nameof(bufferOffset), 0, buffer.Length); + Preconditions.CheckIndexRange(length, nameof(length), 0, buffer.Length - bufferOffset); + + var intDataOffset = (int)dataOffset; + var sourceLength = Math.Min(length, value.StringValue.Length - intDataOffset); + var destLength = Math.Min(length, buffer.Length - bufferOffset); + destLength = Math.Min(destLength, sourceLength); + + if (destLength <= 0) + { + return 0; + } + if (bufferOffset + destLength > buffer.Length) + { + return 0; + } + + var chars = value.StringValue.ToCharArray(); + if (intDataOffset >= chars.Length) + { + return 0; + } + + Array.Copy(chars, dataOffset, buffer, bufferOffset, destLength); + + return destLength; + } + + public override string GetDataTypeName(int ordinal) + { + CheckValidOrdinal(ordinal); + return GetTypeName(LibRows.Metadata!.RowType.Fields[ordinal].Type); + } + + private static string GetTypeName(Google.Cloud.Spanner.V1.Type type) + { + if (type.Code == TypeCode.Array) + { + return type.Code.GetOriginalName() + "<" + type.ArrayElementType.Code.GetOriginalName() + ">"; + } + return type.Code.GetOriginalName(); + } + + public override DateTime GetDateTime(int ordinal) + { + var value = GetProtoValue(ordinal); + CheckNotNull(ordinal); + var type = GetSpannerType(ordinal); + if (type.Code == TypeCode.Date) + { + var date = DateOnly.Parse(value.StringValue); + return date.ToDateTime(TimeOnly.MinValue); + } + if (value.HasStringValue) + { + try + { + return XmlConvert.ToDateTime(value.StringValue, XmlDateTimeSerializationMode.Utc); + } + catch (Exception exception) + { + throw new InvalidCastException(exception.Message, exception); + } + } + throw new InvalidCastException("not a valid DateTime value"); + } + + public override decimal GetDecimal(int ordinal) + { + var value = GetProtoValue(ordinal); + CheckNotNull(ordinal); + if (value.HasStringValue) + { + try + { + return decimal.Parse(value.StringValue, NumberStyles.AllowDecimalPoint | NumberStyles.AllowLeadingSign | NumberStyles.AllowExponent, CultureInfo.InvariantCulture); + } + catch (Exception exception) + { + throw new InvalidCastException(exception.Message, exception); + } + } + throw new InvalidCastException("not a valid decimal value"); + } + + public override double GetDouble(int ordinal) + { + var value = GetProtoValue(ordinal); + CheckNotNull(ordinal); + if (value.HasStringValue) + { + try + { + return double.Parse(value.StringValue, + NumberStyles.AllowDecimalPoint | NumberStyles.AllowLeadingSign | NumberStyles.AllowExponent, + CultureInfo.InvariantCulture); + } + catch (Exception exception) + { + throw new InvalidCastException(exception.Message, exception); + } + } + if (value.HasNumberValue) + { + return value.NumberValue; + } + throw new InvalidCastException("not a valid double value"); + } + + public override System.Type GetFieldType(int ordinal) + { + CheckValidOrdinal(ordinal); + return GetClrType(LibRows.Metadata!.RowType.Fields[ordinal].Type); + } + + private static System.Type GetClrType(Google.Cloud.Spanner.V1.Type type) + { + return type.Code switch + { + TypeCode.Array => typeof(List<>).MakeGenericType(GetClrType(type.ArrayElementType)), + TypeCode.Bool => typeof(bool), + TypeCode.Bytes => typeof(byte[]), + TypeCode.Date => typeof(DateOnly), + TypeCode.Enum => typeof(int), + TypeCode.Float32 => typeof(float), + TypeCode.Float64 => typeof(double), + TypeCode.Int64 => typeof(long), + TypeCode.Interval => typeof(TimeSpan), + TypeCode.Json => typeof(string), + TypeCode.Numeric => typeof(decimal), + TypeCode.Proto => typeof(byte[]), + TypeCode.String => typeof(string), + TypeCode.Timestamp => typeof(DateTime), + TypeCode.Uuid => typeof(Guid), + _ => typeof(Value) + }; + } + + public override float GetFloat(int ordinal) + { + var value = GetProtoValue(ordinal); + CheckNotNull(ordinal); + if (value.HasStringValue) + { + try + { + return float.Parse(value.StringValue, + NumberStyles.AllowDecimalPoint | NumberStyles.AllowLeadingSign | NumberStyles.AllowExponent, + CultureInfo.InvariantCulture); + } + catch (Exception exception) + { + throw new InvalidCastException(exception.Message, exception); + } + } + var type = GetSpannerType(ordinal); + if (type.Code == TypeCode.Float32) + { + return (float)value.NumberValue; + } + throw new InvalidCastException("not a valid float value"); + } + + public override Guid GetGuid(int ordinal) + { + var value = GetProtoValue(ordinal); + CheckNotNull(ordinal); + if (value.HasStringValue) + { + try + { + return Guid.Parse(value.StringValue); + } + catch (Exception exception) + { + throw new InvalidCastException(exception.Message, exception); + } + } + throw new InvalidCastException("not a valid Guid value"); + } + + public override short GetInt16(int ordinal) + { + var value = GetProtoValue(ordinal); + CheckNotNull(ordinal); + if (value.HasStringValue) + { + try + { + return short.Parse(value.StringValue, + NumberStyles.AllowDecimalPoint | NumberStyles.AllowLeadingSign | NumberStyles.AllowExponent, + CultureInfo.InvariantCulture); + } + catch (OverflowException) + { + throw; + } + catch (Exception exception) + { + throw new InvalidCastException(exception.Message, exception); + } + } + if (value.HasNumberValue) + { + return (short)value.NumberValue; + } + throw new InvalidCastException("not a valid Int16 value"); + } + + public override int GetInt32(int ordinal) + { + var value = GetProtoValue(ordinal); + CheckNotNull(ordinal); + if (value.HasStringValue) + { + try + { + return int.Parse(value.StringValue, + NumberStyles.AllowDecimalPoint | NumberStyles.AllowLeadingSign | NumberStyles.AllowExponent, + CultureInfo.InvariantCulture); + } + catch (OverflowException) + { + throw; + } + catch (Exception exception) + { + throw new InvalidCastException(exception.Message, exception); + } + } + if (value.HasNumberValue) + { + return (int)value.NumberValue; + } + throw new InvalidCastException("not a valid Int32 value"); + } + + public override long GetInt64(int ordinal) + { + var value = GetProtoValue(ordinal); + CheckNotNull(ordinal); + if (value.HasStringValue) + { + try + { + return long.Parse(value.StringValue, + NumberStyles.AllowDecimalPoint | NumberStyles.AllowLeadingSign | NumberStyles.AllowExponent, + CultureInfo.InvariantCulture); + } + catch (Exception exception) + { + throw new InvalidCastException(exception.Message, exception); + } + } + if (value.HasNumberValue) + { + return (long)value.NumberValue; + } + throw new InvalidCastException("not a valid Int64 value"); + } + + public TimeSpan GetTimeSpan(int ordinal) => GetFieldValue(ordinal); + + public override string GetName(int ordinal) + { + CheckValidOrdinal(ordinal); + return LibRows.Metadata!.RowType.Fields[ordinal].Name; + } + + public override int GetOrdinal(string name) + { + CheckNotClosed(); + // First try with case sensitivity. + for (var i = 0; i < LibRows.Metadata?.RowType.Fields.Count; i++) + { + if (Equals(LibRows.Metadata?.RowType.Fields[i].Name, name)) + { + return i; + } + } + // Nothing found, try with case-insensitive comparison. + for (var i = 0; i < LibRows.Metadata?.RowType.Fields.Count; i++) + { + if (string.Equals(LibRows.Metadata?.RowType.Fields[i].Name, name, StringComparison.InvariantCultureIgnoreCase)) + { + return i; + } + } + throw new IndexOutOfRangeException($"No column with name {name} found"); + } + + public override T GetFieldValue(int ordinal) + { + CheckNotClosed(); + CheckValidPosition(); + CheckValidOrdinal(ordinal); + if (typeof(T) == typeof(Stream)) + { + CheckNotNull(ordinal); + return (T)(object)GetStream(ordinal); + } + if (typeof(T) == typeof(TextReader)) + { + CheckNotNull(ordinal); + return (T)(object)GetTextReader(ordinal); + } + if (typeof(T) == typeof(char) || typeof(T) == typeof(char?)) + { + if (IsDBNull(ordinal) && typeof(T) == typeof(char?)) + { + return (T)(object)null!; + } + return (T)(object)GetChar(ordinal); + } + if (typeof(T) == typeof(DateTime) || typeof(T) == typeof(DateTime?)) + { + if (IsDBNull(ordinal) && typeof(T) == typeof(DateTime?)) + { + return (T)(object)null!; + } + return (T)(object)GetDateTime(ordinal); + } + if (typeof(T) == typeof(double) || typeof(T) == typeof(double?)) + { + if (IsDBNull(ordinal) && typeof(T) == typeof(double?)) + { + return (T)(object)null!; + } + return (T)(object)GetDouble(ordinal); + } + if (typeof(T) == typeof(float) || typeof(T) == typeof(float?)) + { + if (IsDBNull(ordinal) && typeof(T) == typeof(float?)) + { + return (T)(object)null!; + } + return (T)(object)GetFloat(ordinal); + } + if (typeof(T) == typeof(Int16) || typeof(T) == typeof(Int16?)) + { + if (IsDBNull(ordinal) && typeof(T) == typeof(Int16?)) + { + return (T)(object)null!; + } + return (T)(object)GetInt16(ordinal); + } + if (typeof(T) == typeof(int) || typeof(T) == typeof(int?)) + { + if (IsDBNull(ordinal) && typeof(T) == typeof(int?)) + { + return (T)(object)null!; + } + return (T)(object)GetInt32(ordinal); + } + if (typeof(T) == typeof(long) || typeof(T) == typeof(long?)) + { + if (IsDBNull(ordinal) && typeof(T) == typeof(long?)) + { + return (T)(object)null!; + } + return (T)(object)GetInt64(ordinal); + } + + return base.GetFieldValue(ordinal); + } + + public override object GetValue(int ordinal) + { + CheckValidOrdinal(ordinal); + CheckValidPosition(); + var type = LibRows.Metadata!.RowType.Fields[ordinal].Type; + var value = _currentRow!.Values[ordinal]; + return GetUnderlyingValue(type, value); + } + + private static object GetUnderlyingValue(Google.Cloud.Spanner.V1.Type type, Value value) + { + if (value.HasNullValue) + { + return DBNull.Value; + } + + switch (type.Code) + { + case TypeCode.Array: + var listType = typeof(List<>).MakeGenericType(GetClrType(type.ArrayElementType)); + var list = (IList)Activator.CreateInstance(listType); + foreach (var element in value.ListValue.Values) + { + list.Add(GetUnderlyingValue(type.ArrayElementType, element)); + } + return list; + case TypeCode.Bool: + return value.BoolValue; + case TypeCode.Bytes: + return Convert.FromBase64String(value.StringValue); + case TypeCode.Date: + return DateOnly.Parse(value.StringValue); + case TypeCode.Enum: + return long.Parse(value.StringValue); + case TypeCode.Float32: + return (float)value.NumberValue; + case TypeCode.Float64: + return value.NumberValue; + case TypeCode.Int64: + return long.Parse(value.StringValue); + case TypeCode.Interval: + return XmlConvert.ToTimeSpan(value.StringValue); + case TypeCode.Json: + return value.StringValue; + case TypeCode.Numeric: + return decimal.Parse(value.StringValue, NumberStyles.AllowDecimalPoint | NumberStyles.AllowExponent | NumberStyles.AllowLeadingSign, CultureInfo.InvariantCulture); + case TypeCode.Proto: + return Convert.FromBase64String(value.StringValue); + case TypeCode.String: + return value.StringValue; + case TypeCode.Timestamp: + return XmlConvert.ToDateTime(value.StringValue, XmlDateTimeSerializationMode.Utc); + case TypeCode.Uuid: + return Guid.Parse(value.StringValue); + } + if (value.HasBoolValue) + { + return value.BoolValue; + } + if (value.HasStringValue) + { + return value.StringValue; + } + if (value.HasNumberValue) + { + return value.NumberValue; + } + return value; + } + + private Value GetProtoValue(int ordinal) + { + CheckValidOrdinal(ordinal); + CheckValidPosition(); + return _currentRow!.Values[ordinal]; + } + + private V1.Type GetSpannerType(int ordinal) + { + CheckValidOrdinal(ordinal); + return LibRows.Metadata?.RowType.Fields[ordinal].Type ?? throw new DataException("metadata not found"); + } + + public override int GetValues(object[] values) + { + CheckValidPosition(); + GaxPreconditions.CheckNotNull(values, nameof(values)); + + var count = Math.Min(FieldCount, values.Length); + for (var i = 0; i < count; i++) + { + values[i] = this[i]; + } + + return count; + } + + public override bool IsDBNull(int ordinal) + { + var value = GetProtoValue(ordinal); + return value.HasNullValue; + } + + public override bool NextResult() + { + CheckNotClosed(); + return false; + } + + public override IEnumerator GetEnumerator() + { + CheckNotClosed(); + return new DbEnumerator(this); + } + + private void CheckValidPosition() + { + CheckNotClosed(); + if (_currentRow == null) + { + throw new InvalidOperationException("DataReader is before the first row or after the last row"); + } + } + + private void CheckValidOrdinal(int ordinal) + { + CheckNotClosed(); + var metadata = LibRows.Metadata; + GaxPreconditions.CheckState(metadata != null && metadata.RowType.Fields.Count > 0, "This reader does not contain any rows"); + + // Check that the ordinal is within the range of the columns in the query. + if (ordinal < 0 || ordinal >= metadata!.RowType.Fields.Count) + { + throw new IndexOutOfRangeException("ordinal is out of range"); + } + } + + private void CheckNotNull(int ordinal) + { + if (_currentRow?.Values[ordinal]?.HasNullValue ?? false) + { + throw new InvalidCastException("Value is null"); + } + } + +} \ No newline at end of file diff --git a/drivers/spanner-ado-net/spanner-ado-net/SpannerDataSource.cs b/drivers/spanner-ado-net/spanner-ado-net/SpannerDataSource.cs new file mode 100644 index 00000000..34bc693e --- /dev/null +++ b/drivers/spanner-ado-net/spanner-ado-net/SpannerDataSource.cs @@ -0,0 +1,72 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +using System; +using System.Data.Common; +using System.Diagnostics.CodeAnalysis; +using System.Threading; +using System.Threading.Tasks; +using Google.Api.Gax; + +namespace Google.Cloud.Spanner.DataProvider; + +public class SpannerDataSource : DbDataSource +{ + private readonly SpannerConnectionStringBuilder _connectionStringBuilder; + + [AllowNull] + public sealed override string ConnectionString => _connectionStringBuilder.ConnectionString; + + public static SpannerDataSource Create(string connectionString) + { + GaxPreconditions.CheckNotNull(connectionString, nameof(connectionString)); + return Create(new SpannerConnectionStringBuilder(connectionString)); + } + + public static SpannerDataSource Create(SpannerConnectionStringBuilder connectionStringBuilder) + { + return new SpannerDataSource(connectionStringBuilder); + } + + private SpannerDataSource(SpannerConnectionStringBuilder connectionStringBuilder) + { + GaxPreconditions.CheckNotNull(connectionStringBuilder, nameof(connectionStringBuilder)); + connectionStringBuilder.CheckValid(); + _connectionStringBuilder = connectionStringBuilder; + } + + public new SpannerConnection CreateConnection() => (base.CreateConnection() as SpannerConnection)!; + + public new SpannerConnection OpenConnection() => (base.OpenDbConnection() as SpannerConnection)!; + + public new async ValueTask OpenConnectionAsync(CancellationToken cancellationToken = default) + { + var connection = CreateConnection(); + try + { + await connection.OpenAsync(cancellationToken).ConfigureAwait(false); + return connection; + } + catch + { + await connection.DisposeAsync().ConfigureAwait(false); + throw; + } + } + + protected override DbConnection CreateDbConnection() + { + return new SpannerConnection(_connectionStringBuilder); + } +} diff --git a/drivers/spanner-ado-net/spanner-ado-net/SpannerDbException.cs b/drivers/spanner-ado-net/spanner-ado-net/SpannerDbException.cs new file mode 100644 index 00000000..619015b8 --- /dev/null +++ b/drivers/spanner-ado-net/spanner-ado-net/SpannerDbException.cs @@ -0,0 +1,54 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +using System; +using System.Data.Common; +using Google.Cloud.SpannerLib; +using Google.Rpc; + +namespace Google.Cloud.Spanner.DataProvider; + +public class SpannerDbException : DbException +{ + internal static T TranslateException(Func func) + { + try + { + return func(); + } + catch (SpannerException exception) + { + throw TranslateException(exception); + } + } + + internal static Exception TranslateException(SpannerException exception) + { + if (exception.Code == Code.Cancelled) + { + return new OperationCanceledException(exception.Message, exception); + } + return new SpannerDbException(exception); + } + + private SpannerException SpannerException { get; } + + public Status Status => SpannerException.Status; + + internal SpannerDbException(SpannerException spannerException) : base(spannerException.Message, spannerException) + { + SpannerException = spannerException; + } + +} \ No newline at end of file diff --git a/drivers/spanner-ado-net/spanner-ado-net/SpannerFactory.cs b/drivers/spanner-ado-net/spanner-ado-net/SpannerFactory.cs new file mode 100644 index 00000000..6bfe1967 --- /dev/null +++ b/drivers/spanner-ado-net/spanner-ado-net/SpannerFactory.cs @@ -0,0 +1,94 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +using System; +using System.Data.Common; + +namespace Google.Cloud.Spanner.DataProvider; + +public class SpannerFactory : DbProviderFactory, IServiceProvider +{ + /// + /// Gets an instance of the . + /// This can be used to retrieve strongly typed data objects. + /// + public static readonly SpannerFactory Instance = new(); + + SpannerFactory() {} + + /// + /// Returns a strongly typed instance. + /// + public override DbCommand CreateCommand() => new SpannerCommand(); + + /// + /// Returns a strongly typed instance. + /// + public override DbConnection CreateConnection() => new SpannerConnection(); + + /// + /// Returns a strongly typed instance. + /// + public override DbParameter CreateParameter() => new SpannerParameter(); + + /// + /// Returns a strongly typed instance. + /// + public override DbConnectionStringBuilder CreateConnectionStringBuilder() => new SpannerConnectionStringBuilder(); + + /// + /// Returns a strongly typed instance. + /// + public override DbCommandBuilder CreateCommandBuilder() => new SpannerCommandBuilder(); + + /// + /// Returns a strongly typed instance. + /// + public override DbDataAdapter CreateDataAdapter() => new SpannerDataAdapter(); + + /// + /// Specifies whether the specific supports the class. + /// + public override bool CanCreateDataAdapter => true; + + /// + /// Specifies whether the specific supports the class. + /// + public override bool CanCreateCommandBuilder => true; + + /// + public override bool CanCreateBatch => true; + + /// + public override DbBatch CreateBatch() => new SpannerBatch(); + + /// + public override DbBatchCommand CreateBatchCommand() => new SpannerBatchCommand(); + + /// + public override DbDataSource CreateDataSource(string connectionString) + => SpannerDataSource.Create(connectionString); + + #region IServiceProvider Members + + /// + /// Gets the service object of the specified type. + /// + /// An object that specifies the type of service object to get. + /// A service object of type serviceType, or null if there is no service object of type serviceType. + public object? GetService(System.Type serviceType) => null; + + #endregion + +} \ No newline at end of file diff --git a/drivers/spanner-ado-net/spanner-ado-net/SpannerParameter.cs b/drivers/spanner-ado-net/spanner-ado-net/SpannerParameter.cs new file mode 100644 index 00000000..f0a19685 --- /dev/null +++ b/drivers/spanner-ado-net/spanner-ado-net/SpannerParameter.cs @@ -0,0 +1,218 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +using System; +using System.Collections; +using System.Data; +using System.Data.Common; +using System.Diagnostics.CodeAnalysis; +using System.Globalization; +using System.IO; +using System.Text.Json; +using System.Text.RegularExpressions; +using System.Xml; +using Google.Api.Gax; +using Google.Cloud.Spanner.V1; +using Google.Protobuf.WellKnownTypes; +using TypeCode = Google.Cloud.Spanner.V1.TypeCode; + +namespace Google.Cloud.Spanner.DataProvider; + +public class SpannerParameter : DbParameter, IDbDataParameter, ICloneable +{ + private DbType? _dbType; + + public override DbType DbType + { + get => _dbType ?? DbType.String; + set => _dbType = value; + } + + /// + /// SpannerParameterType overrides the standard DbType property with a specific Spanner type. + /// Use this property if you need to set a specific Spanner type that is not supported by DbType, such as + /// one of the Spanner array types. + /// + public V1.Type? SpannerParameterType { get; set; } + + public override ParameterDirection Direction { get; set; } = ParameterDirection.Input; + public override bool IsNullable { get; set; } + + public new byte Precision { get; set; } + + public new byte Scale { get; set; } + + private string _name = ""; + [AllowNull] public override string ParameterName + { + get => _name; + set => _name = value ?? ""; + } + + private string _sourceColumn = ""; + + [AllowNull] + public override string SourceColumn + { + get => _sourceColumn; + set => _sourceColumn = value ?? ""; + } + public sealed override object? Value { get; set; } + public override bool SourceColumnNullMapping { get; set; } + public override int Size { get; set; } + public override DataRowVersion SourceVersion + { + get => DataRowVersion.Current; + set { } + } + + public SpannerParameter() { } + + public SpannerParameter(string name, object? value) + { + GaxPreconditions.CheckNotNull(name, nameof(name)); + _name = name; + Value = value; + } + + public override void ResetDbType() + { + _dbType = null; + } + + internal Value ConvertToProto(DbParameter dbParameter) + { + GaxPreconditions.CheckState(dbParameter.Direction != ParameterDirection.Input || Value != null, $"Parameter {ParameterName} has no value"); + return ConvertToProto(Value); + } + + internal Google.Cloud.Spanner.V1.Type? GetSpannerType() + { + return SpannerParameterType ?? TypeConversion.GetSpannerType(_dbType); + } + + private Value ConvertToProto(object? value) + { + var type = GetSpannerType(); + return ConvertToProto(value, type); + } + + private static Value ConvertToProto(object? value, Google.Cloud.Spanner.V1.Type? type) + { + var proto = new Value(); + switch (value) + { + case null: + case DBNull: + proto.NullValue = NullValue.NullValue; + break; + case bool b: + proto.BoolValue = b; + break; + case double d: + proto.NumberValue = d; + break; + case float f: + proto.NumberValue = f; + break; + case string str: + proto.StringValue = str; + break; + case Regex regex: + proto.StringValue = regex.ToString(); + break; + case byte b: + proto.StringValue = b.ToString(); + break; + case byte[] bytes: + proto.StringValue = Convert.ToBase64String(bytes); + break; + case MemoryStream memoryStream: + // TODO: Optimize this + proto.StringValue = Convert.ToBase64String(memoryStream.ToArray()); + break; + case short s: + proto.StringValue = s.ToString(); + break; + case int i: + proto.StringValue = i.ToString(); + break; + case long l: + proto.StringValue = l.ToString(); + break; + case decimal d: + proto.StringValue = d.ToString(CultureInfo.InvariantCulture); + break; + case SpannerNumeric num: + proto.StringValue = num.ToString(); + break; + case DateOnly d: + proto.StringValue = d.ToString("O"); + break; + case SpannerDate d: + proto.StringValue = d.ToString(); + break; + case DateTime d: + // Some framework pass DATE values as DateTime. + if (type?.Code == TypeCode.Date) + { + proto.StringValue = d.Date.ToString("yyyy-MM-dd"); + } + else + { + proto.StringValue = d.ToUniversalTime().ToString("O"); + } + break; + case TimeSpan t: + proto.StringValue = XmlConvert.ToString(t); + break; + case JsonDocument jd: + proto.StringValue = jd.RootElement.ToString(); + break; + case IEnumerable list: + var elementType = type?.ArrayElementType; + proto.ListValue = new ListValue(); + foreach (var item in list) + { + proto.ListValue.Values.Add(ConvertToProto(item, elementType)); + } + break; + default: + // Unknown type. Just try to send it as a string. + proto.StringValue = value.ToString(); + break; + } + return proto; + } + + object ICloneable.Clone() => Clone(); + + public SpannerParameter Clone() + { + var clone = new SpannerParameter(_name, Value) + { + _dbType = _dbType, + Direction = Direction, + IsNullable = IsNullable, + Precision = Precision, + Scale = Scale, + Size = Size, + SourceColumn = SourceColumn, + SourceVersion = SourceVersion, + SourceColumnNullMapping = SourceColumnNullMapping, + }; + return clone; + } + +} \ No newline at end of file diff --git a/drivers/spanner-ado-net/spanner-ado-net/SpannerParameterCollection.cs b/drivers/spanner-ado-net/spanner-ado-net/SpannerParameterCollection.cs new file mode 100644 index 00000000..26691ae5 --- /dev/null +++ b/drivers/spanner-ado-net/spanner-ado-net/SpannerParameterCollection.cs @@ -0,0 +1,274 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +using System; +using System.Collections; +using System.Collections.Generic; +using System.Data.Common; +using Google.Api.Gax; +using Google.Protobuf.Collections; +using Google.Protobuf.WellKnownTypes; + +namespace Google.Cloud.Spanner.DataProvider; + +public class SpannerParameterCollection : DbParameterCollection +{ + private readonly List _params = new (); + public override int Count => _params.Count; + public override object SyncRoot => _params; + + public override int Add(object value) + { + GaxPreconditions.CheckNotNull(value, nameof(value)); + var index = _params.Count; + if (value is SpannerParameter spannerParameter) + { + _params.Add(spannerParameter); + } + else if (value is DbParameter) + { + throw new ArgumentException("value is not a SpannerParameter"); + } + else + { + _params.Add(new SpannerParameter { ParameterName = "p" + (index + 1), Value = value }); + } + + return index; + } + + public SpannerParameter AddWithValue(string parameterName, object? value) + { + var parameter = new SpannerParameter + { + ParameterName = parameterName, + Value = value, + }; + Add(parameter); + return parameter; + } + + public override void Clear() + { + _params.Clear(); + } + + public override bool Contains(object value) + { + return IndexOf(value) > -1; + } + + public override int IndexOf(object value) + { + if (value is SpannerParameter spannerParameter) + { + return _params.IndexOf(spannerParameter); + } + return _params.FindIndex(p => Equals(p.Value, value)); + } + + public override void Insert(int index, object value) + { + GaxPreconditions.CheckNotNull(value, nameof(value)); + if (value is SpannerParameter spannerParameter) + { + _params.Insert(index, spannerParameter); + } + else if (value is DbParameter) + { + throw new ArgumentException("value is not a SpannerParameter"); + } + else + { + _params.Insert(index, new SpannerParameter { ParameterName = "p" + (index + 1), Value = value }); + } + } + + public override void Remove(object value) + { + GaxPreconditions.CheckNotNull(value, nameof(value)); + var index = IndexOf(value); + if (index > -1) + { + _params.RemoveAt(index); + } + } + + public override void RemoveAt(int index) + { + _params.RemoveAt(index); + } + + public override void RemoveAt(string parameterName) + { + var index = IndexOf(parameterName); + if (index > -1) + { + _params.RemoveAt(index); + } + } + + protected override void SetParameter(int index, DbParameter value) + { + GaxPreconditions.CheckNotNull(value, nameof(value)); + if (value is SpannerParameter spannerParameter) + { + _params[index] = spannerParameter; + } + else + { + throw new ArgumentException("value is not a SpannerParameter"); + } + } + + protected override void SetParameter(string parameterName, DbParameter value) + { + GaxPreconditions.CheckNotNull(value, nameof(value)); + if (value is SpannerParameter spannerParameter) + { + if (spannerParameter.ParameterName == "") + { + spannerParameter.ParameterName = parameterName; + } + else if (!spannerParameter.ParameterName.Equals(parameterName, StringComparison.OrdinalIgnoreCase)) + { + throw new ArgumentException("Parameter names mismatch"); + } + var index = IndexOf(parameterName); + if (index > -1) + { + _params[index] = spannerParameter; + } + else + { + spannerParameter.ParameterName = parameterName; + Add(spannerParameter); + } + } + else + { + throw new ArgumentException("value is not a SpannerParameter"); + } + } + + public override int IndexOf(string parameterName) + { + var result = _params.FindIndex(p => Equals(p.ParameterName, parameterName)); + if (result > -1) + { + return result; + } + return _params.FindIndex(p => p.ParameterName.Equals(parameterName, StringComparison.OrdinalIgnoreCase)); + } + + public override bool Contains(string value) + { + return IndexOf(value) > -1; + } + + public override void CopyTo(Array array, int index) + { + if (array == null) + { + throw new ArgumentNullException(nameof(array)); + } + + if (array.Length < _params.Count + index) + { + throw new ArgumentOutOfRangeException( + nameof(array), "There is not enough space in the array to copy values."); + } + + foreach (var item in _params) + { + array.SetValue(item, index); + index++; + } + } + + public override IEnumerator GetEnumerator() + { + return _params.GetEnumerator(); + } + + protected override DbParameter GetParameter(int index) + { + return _params[index]; + } + +#pragma warning disable CS8764 + protected override DbParameter? GetParameter(string parameterName) +#pragma warning restore CS8764 + { + var index = IndexOf(parameterName); + return index > -1 ? _params[index] : null; + } + + public override void AddRange(Array values) + { + foreach (var value in values) + { + Add(value); + } + } + + internal Tuple> CreateSpannerParams() + { + var queryParams = new Struct(); + var paramTypes = new MapField(); + for (var index = 0; index < Count; index++) + { + var param = this[index]; + if (param is SpannerParameter spannerParameter) + { + var name = param.ParameterName; + if (name.StartsWith("@")) + { + name = name[1..]; + } + else if (name.StartsWith("$")) + { + name = "p" + name[1..]; + } + else if (string.IsNullOrEmpty(name)) + { + name = "p" + (index + 1); + } + queryParams.Fields.Add(name, spannerParameter.ConvertToProto(spannerParameter)); + var paramType = spannerParameter.GetSpannerType(); + if (paramType != null) + { + paramTypes.Add(name, paramType); + } + } + else + { + throw new InvalidOperationException("parameter is not a SpannerParameter: " + param.ParameterName); + } + } + return Tuple.Create(queryParams, paramTypes); + } + + internal void CloneTo(SpannerParameterCollection other) + { + GaxPreconditions.CheckNotNull(other, nameof(other)); + other._params.Clear(); + foreach (var param in _params) + { + var newParam = param.Clone(); + other._params.Add(newParam); + } + } + +} \ No newline at end of file diff --git a/drivers/spanner-ado-net/spanner-ado-net/SpannerPool.cs b/drivers/spanner-ado-net/spanner-ado-net/SpannerPool.cs new file mode 100644 index 00000000..6d401971 --- /dev/null +++ b/drivers/spanner-ado-net/spanner-ado-net/SpannerPool.cs @@ -0,0 +1,97 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +using System.Collections.Concurrent; +using System.Collections.Generic; +using System.Runtime.CompilerServices; +using Google.Cloud.SpannerLib; +using Google.Cloud.SpannerLib.Grpc; +using Google.Cloud.SpannerLib.Native.Impl; + +namespace Google.Cloud.Spanner.DataProvider; + +internal class SpannerPool +{ + private static ISpannerLib? _gRpcSpannerLib; + + private static ISpannerLib GrpcSpannerLib + { + get + { + _gRpcSpannerLib ??= new GrpcLibSpanner(); + return _gRpcSpannerLib; + } + } + + private static ISpannerLib? _nativeSpannerLib; + + private static ISpannerLib NativeSpannerLib + { + get + { + _nativeSpannerLib ??= new SharedLibSpanner(); + return _nativeSpannerLib; + } + } + + private static readonly ConcurrentDictionary Pools = new(); + + [MethodImpl(MethodImplOptions.Synchronized)] + internal static SpannerPool GetOrCreate(string dsn, bool useNativeLibrary = false) + { + if (Pools.TryGetValue(dsn, out var value)) + { + return value; + } + var pool = Pool.Create(useNativeLibrary ? NativeSpannerLib : GrpcSpannerLib, dsn); + var spannerPool = new SpannerPool(dsn, pool); + Pools[dsn] = spannerPool; + return spannerPool; + } + + [MethodImpl(MethodImplOptions.Synchronized)] + internal static void CloseSpannerLib() + { + foreach (var pool in Pools.Values) + { + pool.Close(); + } + Pools.Clear(); + GrpcSpannerLib.Dispose(); + _gRpcSpannerLib = null; + NativeSpannerLib.Dispose(); + _nativeSpannerLib = null; + } + + private readonly string _dsn; + + private readonly Pool _libPool; + + private SpannerPool(string dsn, Pool libPool) + { + _dsn = dsn; + _libPool = libPool; + } + + internal void Close() + { + _libPool.Close(); + Pools.Remove(_dsn, out _); + } + + internal Connection CreateConnection() + { + return _libPool.CreateConnection(); + } +} \ No newline at end of file diff --git a/drivers/spanner-ado-net/spanner-ado-net/SpannerSchemaProvider.cs b/drivers/spanner-ado-net/spanner-ado-net/SpannerSchemaProvider.cs new file mode 100644 index 00000000..73ecdc32 --- /dev/null +++ b/drivers/spanner-ado-net/spanner-ado-net/SpannerSchemaProvider.cs @@ -0,0 +1,359 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +using System; +using System.Data; +using System.Data.Common; +using Google.Api.Gax; +using TypeCode = Google.Cloud.Spanner.V1.TypeCode; + +namespace Google.Cloud.Spanner.DataProvider; + +internal sealed class SpannerSchemaProvider(SpannerConnection connection) +{ + internal DataTable GetSchema() => GetSchema("MetaDataCollections", null); + + internal DataTable GetSchema(string collectionName, string?[]? restrictionValues) + { + GaxPreconditions.CheckNotNull(collectionName, nameof(collectionName)); + var dataTable = new DataTable(); + if (string.Equals(collectionName, DbMetaDataCollectionNames.MetaDataCollections, StringComparison.OrdinalIgnoreCase)) + { + FillMetaDataCollections(dataTable, restrictionValues); + } + else if (string.Equals(collectionName, DbMetaDataCollectionNames.DataSourceInformation, StringComparison.OrdinalIgnoreCase)) + { + FillDataSourceInformation(dataTable, restrictionValues); + } + else if (string.Equals(collectionName, DbMetaDataCollectionNames.DataTypes, StringComparison.OrdinalIgnoreCase)) + { + FillDataTypes(dataTable, restrictionValues); + } + else if (string.Equals(collectionName, DbMetaDataCollectionNames.ReservedWords, StringComparison.OrdinalIgnoreCase)) + { + FillReservedWords(dataTable, restrictionValues); + } + else if (string.Equals(collectionName, DbMetaDataCollectionNames.Restrictions, StringComparison.OrdinalIgnoreCase)) + { + FillRestrictions(dataTable, restrictionValues); + } + else + { + throw new ArgumentException($"Invalid collection name: '{collectionName}'.", nameof(collectionName)); + } + return dataTable; + } + + private void FillMetaDataCollections(DataTable dataTable, string?[]? restrictionValues) + { + GaxPreconditions.CheckArgument(restrictionValues == null || restrictionValues.Length == 0, nameof(restrictionValues), "restrictionValues is not supported for schema 'MetaDataCollections'."); + + dataTable.TableName = DbMetaDataCollectionNames.MetaDataCollections; + dataTable.Columns.AddRange( + [ + new("CollectionName", typeof(string)), + new("NumberOfRestrictions", typeof(int)), + new("NumberOfIdentifierParts", typeof(int)), + ]); + + dataTable.Rows.Add(DbMetaDataCollectionNames.MetaDataCollections, 0, 0); + dataTable.Rows.Add(DbMetaDataCollectionNames.DataSourceInformation, 0, 0); + dataTable.Rows.Add(DbMetaDataCollectionNames.DataTypes, 0, 0); + dataTable.Rows.Add(DbMetaDataCollectionNames.ReservedWords, 0, 0); + dataTable.Rows.Add(DbMetaDataCollectionNames.Restrictions, 0, 0); + } + + private void FillDataSourceInformation(DataTable dataTable, string?[]? restrictionValues) + { + GaxPreconditions.CheckArgument(restrictionValues == null || restrictionValues.Length == 0, nameof(restrictionValues), "restrictionValues is not supported for schema 'DataSourceInformation'."); + + dataTable.TableName = DbMetaDataCollectionNames.DataSourceInformation; + dataTable.Columns.AddRange( + [ + new("CompositeIdentifierSeparatorPattern", typeof(string)), + new("DataSourceProductName", typeof(string)), + new("DataSourceProductVersion", typeof(string)), + new("DataSourceProductVersionNormalized", typeof(string)), + new("GroupByBehavior", typeof(GroupByBehavior)), + new("IdentifierPattern", typeof(string)), + new("IdentifierCase", typeof(IdentifierCase)), + new("OrderByColumnsInSelect", typeof(bool)), + new("ParameterMarkerFormat", typeof(string)), + new("ParameterMarkerPattern", typeof(string)), + new("ParameterNameMaxLength", typeof(int)), + new("QuotedIdentifierPattern", typeof(string)), + new("QuotedIdentifierCase", typeof(IdentifierCase)), + new("ParameterNamePattern", typeof(string)), + new("StatementSeparatorPattern", typeof(string)), + new("StringLiteralPattern", typeof(string)), + new("SupportedJoinOperators", typeof(SupportedJoinOperators)), + ]); + + var row = dataTable.NewRow(); + row["CompositeIdentifierSeparatorPattern"] = @"\."; + row["DataSourceProductName"] = "Spanner"; + row["DataSourceProductVersion"] = connection.ServerVersion; + row["DataSourceProductVersionNormalized"] = connection.ServerVersionNormalizedString; + row["GroupByBehavior"] = GroupByBehavior.Unrelated; + row["IdentifierPattern"] = @"(^\[\p{Lo}\p{Lu}\p{Ll}][\p{Lo}\p{Lu}\p{Ll}\p{Nd}_]*$)"; + row["IdentifierCase"] = IdentifierCase.Insensitive; + row["OrderByColumnsInSelect"] = false; + row["ParameterMarkerFormat"] = "{0}"; + row["ParameterMarkerPattern"] = "(@[A-Za-z0-9_]*)"; + row["ParameterNameMaxLength"] = 128; + row["QuotedIdentifierPattern"] = @"(([^`]|\\`)+)"; + row["QuotedIdentifierCase"] = IdentifierCase.Insensitive; + row["ParameterNamePattern"] = @"[\p{Lo}\p{Lu}\p{Ll}\p{Lm}][\p{Lo}\p{Lu}\p{Ll}\p{Lm}\p{Nd}_]*"; + row["StatementSeparatorPattern"] = ";"; + row["StringLiteralPattern"] = @"'(([^']|'')*)'"; + row["SupportedJoinOperators"] = + SupportedJoinOperators.FullOuter | + SupportedJoinOperators.Inner | + SupportedJoinOperators.LeftOuter | + SupportedJoinOperators.RightOuter; + dataTable.Rows.Add(row); + } + + private void FillDataTypes(DataTable dataTable, string?[]? restrictionValues) + { + GaxPreconditions.CheckArgument(restrictionValues == null || restrictionValues.Length == 0, nameof(restrictionValues), "restrictionValues is not supported for schema 'DataTypes'."); + + dataTable.TableName = DbMetaDataCollectionNames.DataTypes; + dataTable.Columns.AddRange( + [ + new("TypeName", typeof(string)), + new("ProviderDbType", typeof(int)), + new("ColumnSize", typeof(long)), + new("CreateFormat", typeof(string)), + new("CreateParameters", typeof(string)), + new("DataType", typeof(string)), + new("IsAutoIncrementable", typeof(bool)), + new("IsBestMatch", typeof(bool)), + new("IsCaseSensitive", typeof(bool)), + new("IsFixedLength", typeof(bool)), + new("IsFixedPrecisionScale", typeof(bool)), + new("IsLong", typeof(bool)), + new("IsNullable", typeof(bool)), + new("IsSearchable", typeof(bool)), + new("IsSearchableWithLike", typeof(bool)), + new("IsUnsigned", typeof(bool)), + new("MaximumScale", typeof(short)), + new("MinimumScale", typeof(short)), + new("IsConcurrencyType", typeof(bool)), + new("IsLiteralSupported", typeof(bool)), + new("LiteralPrefix", typeof(string)), + new("LiteralSuffix", typeof(string)), + new("NativeDataType", typeof(string)), + ]); + + foreach (var code in Enum.GetValues()) + { + // These are not supported as stored column types. + if (code == TypeCode.Unspecified || code == TypeCode.Struct) + { + continue; + } + // TODO: Add arrays + if (code == TypeCode.Array) + { + continue; + } + + var clrType = TypeConversion.GetSystemType(code); + var clrTypeName = clrType.ToString(); + // Both STRING and JSON are mapped to System.String. STRING is the best match. + // Both ENUM and INT64 are mapped to System.Int64. INT64 is the best match. + // Both PROTO and BYTES are mapped to System.Byte[]. BYTES is the best match. + var isBestMatch = code != TypeCode.Json && code != TypeCode.Enum && code != TypeCode.Proto; + var dataTypeName = code.ToString(); + var isAutoIncrementable = code == TypeCode.Int64; + var isFixedLength = code != TypeCode.String && code != TypeCode.Bytes && code != TypeCode.Json && code != TypeCode.Proto; + var createFormat = isFixedLength + ? dataTypeName + : dataTypeName + "({0})"; + var createParameters = isFixedLength ? "" : "length"; + var isFixedPrecisionScale = isFixedLength; + var isLong = false; + var columnSize = 0; + var isCaseSensitive = code == TypeCode.String || code == TypeCode.Json; + var isSearchableWithLike = code == TypeCode.String; + object isUnsigned = code == TypeCode.Int64 || code == TypeCode.Float32 || code == TypeCode.Float64 || + code == TypeCode.Numeric ? false : DBNull.Value; + var literalPrefix = $" {code.ToString()}"; + + var row = dataTable.NewRow(); + row["TypeName"] = dataTypeName; + row["ProviderDbType"] = (int) code; + row["ColumnSize"] = columnSize; + row["CreateFormat"] = createFormat; + row["CreateParameters"] = createParameters; + row["DataType"] = clrTypeName; + row["IsAutoIncrementable"] = isAutoIncrementable; + row["IsBestMatch"] = isBestMatch; + row["IsCaseSensitive"] = isCaseSensitive; + row["IsFixedLength"] = isFixedLength; + row["IsFixedPrecisionScale"] = isFixedPrecisionScale; + row["IsLong"] = isLong; + row["IsNullable"] = true; + row["IsSearchable"] = true; + row["IsSearchableWithLike"] = isSearchableWithLike; + row["IsUnsigned"] = isUnsigned; + row["MaximumScale"] = DBNull.Value; + row["MinimumScale"] = DBNull.Value; + row["IsConcurrencyType"] = false; + row["IsLiteralSupported"] = true; + row["LiteralPrefix"] = literalPrefix; + row["LiteralSuffix"] = DBNull.Value; + row["NativeDataType"] = DBNull.Value; + + dataTable.Rows.Add(row); + } + } + private static void FillRestrictions(DataTable dataTable, string?[]? restrictionValues) + { + GaxPreconditions.CheckArgument(restrictionValues == null || restrictionValues.Length == 0, nameof(restrictionValues), "restrictionValues is not supported for schema 'Restrictions'."); + + dataTable.TableName = DbMetaDataCollectionNames.Restrictions; + dataTable.Columns.AddRange( + [ + new("CollectionName", typeof(string)), + new("RestrictionName", typeof(string)), + new("RestrictionDefault", typeof(string)), + new("RestrictionNumber", typeof(int)), + ]); + + dataTable.Rows.Add("Columns", "Catalog", "TABLE_CATALOG", 1); + dataTable.Rows.Add("Columns", "Schema", "TABLE_SCHEMA", 2); + dataTable.Rows.Add("Columns", "Table", "TABLE_NAME", 3); + dataTable.Rows.Add("Columns", "Column", "COLUMN_NAME", 4); + dataTable.Rows.Add("Tables", "Catalog", "TABLE_CATALOG", 1); + dataTable.Rows.Add("Tables", "Schema", "TABLE_SCHEMA", 2); + dataTable.Rows.Add("Tables", "Table", "TABLE_NAME", 3); + dataTable.Rows.Add("Tables", "TableType", "TABLE_TYPE", 4); + dataTable.Rows.Add("Foreign Keys", "Catalog", "TABLE_CATALOG", 1); + dataTable.Rows.Add("Foreign Keys", "Schema", "TABLE_SCHEMA", 2); + dataTable.Rows.Add("Foreign Keys", "Table", "TABLE_NAME", 3); + dataTable.Rows.Add("Foreign Keys", "Constraint Name", "CONSTRAINT_NAME", 4); + dataTable.Rows.Add("Indexes", "Catalog", "TABLE_CATALOG", 1); + dataTable.Rows.Add("Indexes", "Schema", "TABLE_SCHEMA", 2); + dataTable.Rows.Add("Indexes", "Table", "TABLE_NAME", 3); + dataTable.Rows.Add("Indexes", "Name", "INDEX_NAME", 4); + dataTable.Rows.Add("IndexColumns", "Catalog", "TABLE_CATALOG", 1); + dataTable.Rows.Add("IndexColumns", "Schema", "TABLE_SCHEMA", 2); + dataTable.Rows.Add("IndexColumns", "Table", "TABLE_NAME", 3); + dataTable.Rows.Add("IndexColumns", "Name", "INDEX_NAME", 4); + dataTable.Rows.Add("IndexColumns", "Column", "COLUMN_NAME", 5); + } + + private static void FillReservedWords(DataTable dataTable, string?[]? restrictionValues) + { + GaxPreconditions.CheckArgument(restrictionValues == null || restrictionValues.Length == 0, nameof(restrictionValues), "restrictionValues is not supported for schema 'ReservedWords'."); + + dataTable.TableName = DbMetaDataCollectionNames.ReservedWords; + dataTable.Columns.AddRange( + [ + new("ReservedWord", typeof(string)), + ]); + + var keywords = new [] + { + "ALL", + "AND", + "ANY", + "ARRAY", + "AS", + "ASC", + "AT", + "BETWEEN", + "BY", + "CASE", + "CAST", + "CHECK", + "COLUMN", + "COMMIT", + "CONSTRAINT", + "CREATE", + "CROSS", + "CUBE", + "CURRENT", + "DEFAULT", + "DELETE", + "DESC", + "DESCENDING", + "DISTINCT", + "DROP", + "ELSE", + "END", + "ESCAPE", + "EXCEPT", + "EXISTS", + "FALSE", + "FETCH", + "FOLLOWING", + "FOR", + "FOREIGN", + "FROM", + "FULL", + "GROUP", + "GROUPING", + "HAVING", + "IN", + "INNER", + "INSERT", + "INTERSECT", + "INTERVAL", + "INTO", + "IS", + "JOIN", + "LEFT", + "LIKE", + "LIMIT", + "NOT", + "NULL", + "ON", + "OR", + "ORDER", + "OUTER", + "PARTITION", + "PRECEDING", + "PRIMARY", + "REFERENCES", + "RIGHT", + "ROLLUP", + "ROW", + "ROWS", + "SELECT", + "SET", + "SOME", + "TABLE", + "THEN", + "TO", + "TRUE", + "UNBOUNDED", + "UNION", + "UNNEST", + "UPDATE", + "USING", + "VALUES", + "WHEN", + "WHERE", + "WINDOW", + "WITH" + }; + foreach (string word in keywords) + { + dataTable.Rows.Add(word); + } + } +} diff --git a/drivers/spanner-ado-net/spanner-ado-net/SpannerTransaction.cs b/drivers/spanner-ado-net/spanner-ado-net/SpannerTransaction.cs new file mode 100644 index 00000000..b283b980 --- /dev/null +++ b/drivers/spanner-ado-net/spanner-ado-net/SpannerTransaction.cs @@ -0,0 +1,151 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +using System; +using System.Data; +using System.Data.Common; +using System.Threading; +using System.Threading.Tasks; +using Google.Api.Gax; +using Google.Cloud.Spanner.V1; + +namespace Google.Cloud.Spanner.DataProvider; + +public class SpannerTransaction : DbTransaction +{ + private SpannerConnection? _spannerConnection; + + protected override DbConnection? DbConnection => _spannerConnection; + public override IsolationLevel IsolationLevel { get; } + private SpannerLib.Connection LibConnection { get; } + + internal bool IsCompleted => _spannerConnection == null; + + private bool _disposed; + + internal SpannerTransaction(SpannerConnection connection, TransactionOptions options) + { + _spannerConnection = connection; + IsolationLevel = TranslateIsolationLevel(options.IsolationLevel); + LibConnection = connection.LibConnection!; + LibConnection.BeginTransaction(options); + } + + internal static TransactionOptions.Types.IsolationLevel TranslateIsolationLevel(IsolationLevel isolationLevel) + { + return isolationLevel switch + { + IsolationLevel.Chaos => throw new NotSupportedException(), + IsolationLevel.ReadUncommitted => throw new NotSupportedException(), + IsolationLevel.ReadCommitted => throw new NotSupportedException(), + IsolationLevel.RepeatableRead => TransactionOptions.Types.IsolationLevel.RepeatableRead, + IsolationLevel.Snapshot => TransactionOptions.Types.IsolationLevel.RepeatableRead, + IsolationLevel.Serializable => TransactionOptions.Types.IsolationLevel.Serializable, + _ => TransactionOptions.Types.IsolationLevel.Unspecified + }; + } + + private static IsolationLevel TranslateIsolationLevel(TransactionOptions.Types.IsolationLevel isolationLevel) + { + switch (isolationLevel) + { + case TransactionOptions.Types.IsolationLevel.Unspecified: + return IsolationLevel.Unspecified; + case TransactionOptions.Types.IsolationLevel.RepeatableRead: + return IsolationLevel.RepeatableRead; + case TransactionOptions.Types.IsolationLevel.Serializable: + return IsolationLevel.Serializable; + default: + throw new ArgumentOutOfRangeException(nameof(isolationLevel), isolationLevel, + "unsupported isolation level"); + } + } + + protected override void Dispose(bool disposing) + { + if (!IsCompleted) + { + // Do a shoot-and-forget rollback. + RollbackAsync(CancellationToken.None); + } + _disposed = true; + base.Dispose(disposing); + } + + public override ValueTask DisposeAsync() + { + if (!IsCompleted) + { + // Do a shoot-and-forget rollback. + RollbackAsync(CancellationToken.None); + } + _disposed = true; + return base.DisposeAsync(); + } + + private void CheckDisposed() + { + ObjectDisposedException.ThrowIf(_disposed, this); + } + + public override void Commit() + { + EndTransaction(() => LibConnection.Commit()); + } + + public override Task CommitAsync(CancellationToken cancellationToken = default) + { + return EndTransactionAsync(() => LibConnection.CommitAsync(cancellationToken)); + } + + public override void Rollback() + { + EndTransaction(() => LibConnection.Rollback()); + } + + public override Task RollbackAsync(CancellationToken cancellationToken = default) + { + return EndTransactionAsync(() => LibConnection.RollbackAsync(cancellationToken)); + } + + private void EndTransaction(Action endTransactionMethod) + { + CheckDisposed(); + GaxPreconditions.CheckState(!IsCompleted, "This transaction is no longer active"); + try + { + endTransactionMethod(); + } + finally + { + _spannerConnection?.ClearTransaction(); + _spannerConnection = null; + } + } + + private Task EndTransactionAsync(Func endTransactionMethod) + { + CheckDisposed(); + GaxPreconditions.CheckState(!IsCompleted, "This transaction is no longer active"); + try + { + return endTransactionMethod(); + } + finally + { + _spannerConnection?.ClearTransaction(); + _spannerConnection = null; + } + } +} \ No newline at end of file diff --git a/drivers/spanner-ado-net/spanner-ado-net/TypeConversion.cs b/drivers/spanner-ado-net/spanner-ado-net/TypeConversion.cs new file mode 100644 index 00000000..d0bd9a1c --- /dev/null +++ b/drivers/spanner-ado-net/spanner-ado-net/TypeConversion.cs @@ -0,0 +1,89 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +using System; +using System.Collections.Generic; +using System.Data; +using TypeCode = Google.Cloud.Spanner.V1.TypeCode; + +namespace Google.Cloud.Spanner.DataProvider; + +internal static class TypeConversion +{ + private static readonly Dictionary SDbTypeToSpannerTypeMapping = new (); + + static TypeConversion() + { + SDbTypeToSpannerTypeMapping[DbType.Date] = new V1.Type { Code = TypeCode.Date }; + SDbTypeToSpannerTypeMapping[DbType.Binary] = new V1.Type { Code = TypeCode.Bytes }; + SDbTypeToSpannerTypeMapping[DbType.Boolean] = new V1.Type { Code = TypeCode.Bool }; + SDbTypeToSpannerTypeMapping[DbType.Double] = new V1.Type { Code = TypeCode.Float64 }; + SDbTypeToSpannerTypeMapping[DbType.Single] = new V1.Type { Code = TypeCode.Float32 }; + SDbTypeToSpannerTypeMapping[DbType.Guid] = new V1.Type { Code = TypeCode.Uuid }; + + var numericType = new V1.Type { Code = TypeCode.Numeric }; + SDbTypeToSpannerTypeMapping[DbType.Decimal] = numericType; + SDbTypeToSpannerTypeMapping[DbType.VarNumeric] = numericType; + + var timestampType = new V1.Type { Code = TypeCode.Timestamp }; + SDbTypeToSpannerTypeMapping[DbType.DateTime] = timestampType; + SDbTypeToSpannerTypeMapping[DbType.DateTime2] = timestampType; + SDbTypeToSpannerTypeMapping[DbType.DateTimeOffset] = timestampType; + + var int64Type = new V1.Type { Code = TypeCode.Int64 }; + SDbTypeToSpannerTypeMapping[DbType.Byte] = int64Type; + SDbTypeToSpannerTypeMapping[DbType.Int16] = int64Type; + SDbTypeToSpannerTypeMapping[DbType.Int32] = int64Type; + SDbTypeToSpannerTypeMapping[DbType.Int64] = int64Type; + SDbTypeToSpannerTypeMapping[DbType.SByte] = int64Type; + SDbTypeToSpannerTypeMapping[DbType.UInt16] = int64Type; + SDbTypeToSpannerTypeMapping[DbType.UInt32] = int64Type; + SDbTypeToSpannerTypeMapping[DbType.UInt64] = int64Type; + + var stringType = new V1.Type { Code = TypeCode.String }; + SDbTypeToSpannerTypeMapping[DbType.String] = stringType; + SDbTypeToSpannerTypeMapping[DbType.StringFixedLength] = stringType; + SDbTypeToSpannerTypeMapping[DbType.AnsiString] = stringType; + SDbTypeToSpannerTypeMapping[DbType.AnsiStringFixedLength] = stringType; + } + + internal static V1.Type? GetSpannerType(DbType? dbType) + { + return dbType == null ? null : SDbTypeToSpannerTypeMapping.GetValueOrDefault(dbType.Value); + } + + internal static System.Type GetSystemType(V1.Type type) => GetSystemType(type.Code); + + internal static System.Type GetSystemType(TypeCode code) + { + return code switch + { + TypeCode.Bool => typeof(bool), + TypeCode.Bytes => typeof(byte[]), + TypeCode.Date => typeof(DateOnly), + TypeCode.Enum => typeof(long), + TypeCode.Float32 => typeof(float), + TypeCode.Float64 => typeof(double), + TypeCode.Int64 => typeof(long), + TypeCode.Interval => typeof(TimeSpan), + TypeCode.Json => typeof(string), + TypeCode.Numeric => typeof(decimal), + TypeCode.Proto => typeof(byte[]), + TypeCode.String => typeof(string), + TypeCode.Timestamp => typeof(DateTime), + TypeCode.Uuid => typeof(Guid), + _ => typeof(string) + }; + } +} \ No newline at end of file diff --git a/drivers/spanner-ado-net/spanner-ado-net/publish.sh b/drivers/spanner-ado-net/spanner-ado-net/publish.sh new file mode 100644 index 00000000..4425ca25 --- /dev/null +++ b/drivers/spanner-ado-net/spanner-ado-net/publish.sh @@ -0,0 +1,11 @@ +VERSION=$(date -u +"1.0.0-alpha.%Y%m%d%H%M%S") + +echo "Publishing as version $VERSION" +sed -i "" "s|.*|$VERSION|g" spanner-ado-net.csproj + +rm -rf bin/Release +dotnet pack +dotnet nuget push \ + bin/Release/Alpha.Google.Cloud.Spanner.DataProvider.*.nupkg \ + --api-key $NUGET_API_KEY \ + --source https://api.nuget.org/v3/index.json diff --git a/drivers/spanner-ado-net/spanner-ado-net/spanner-ado-net.csproj b/drivers/spanner-ado-net/spanner-ado-net/spanner-ado-net.csproj new file mode 100644 index 00000000..8d70844c --- /dev/null +++ b/drivers/spanner-ado-net/spanner-ado-net/spanner-ado-net.csproj @@ -0,0 +1,27 @@ + + + + net8.0 + Google.Cloud.Spanner.DataProvider + enable + default + Google.Cloud.Spanner.DataProvider + Alpha.Google.Cloud.Spanner.DataProvider + .NET Data Provider for Spanner + Google + 1.0.0-alpha.20251003170157 + ADO.NET Data Provider. + +Alpha version: Not for production use + Apache v2.0 + https://github.com/googleapis/go-sql-spanner/drivers/spanner-ado-net/LICENSE + https://github.com/googleapis/go-sql-spanner/drivers/spanner-ado-net + + + + + + + + + diff --git a/parser/statement_parser.go b/parser/statement_parser.go index d86d8295..ef0875a8 100644 --- a/parser/statement_parser.go +++ b/parser/statement_parser.go @@ -708,6 +708,23 @@ const ( StatementTypeClientSide ) +func (st StatementType) String() string { + switch st { + case StatementTypeQuery: + return "Query" + case StatementTypeDml: + return "DML" + case StatementTypeDdl: + return "DDL" + case StatementTypeClientSide: + return "ClientSide" + case StatementTypeUnknown: + default: + return "Unknown" + } + return "Unknown" +} + // DmlType designates the type of modification that a DML statement will execute. type DmlType int diff --git a/spannerlib/api/batch_test.go b/spannerlib/api/batch_test.go index 1ae9af14..957d7faa 100644 --- a/spannerlib/api/batch_test.go +++ b/spannerlib/api/batch_test.go @@ -236,3 +236,41 @@ func TestExecuteDdlBatchInTransaction(t *testing.T) { t.Fatalf("ClosePool returned unexpected error: %v", err) } } + +func TestExecuteQueryInBatch(t *testing.T) { + t.Parallel() + + ctx := context.Background() + server, teardown := setupMockServer(t) + defer teardown() + dsn := fmt.Sprintf("%s/projects/p/instances/i/databases/d?useplaintext=true", server.Address) + + poolId, err := CreatePool(ctx, dsn) + if err != nil { + t.Fatalf("CreatePool returned unexpected error: %v", err) + } + connId, err := CreateConnection(ctx, poolId) + if err != nil { + t.Fatalf("CreateConnection returned unexpected error: %v", err) + } + + // Try to execute a batch with queries. This should fail. + request := &spannerpb.ExecuteBatchDmlRequest{Statements: []*spannerpb.ExecuteBatchDmlRequest_Statement{ + {Sql: "select 1"}, + {Sql: "select 2"}, + }} + _, err = ExecuteBatch(ctx, poolId, connId, request) + if g, w := spanner.ErrCode(err), codes.InvalidArgument; g != w { + t.Fatalf("error code mismatch\n Got: %v\nWant: %v", g, w) + } + if g, w := err.Error(), "rpc error: code = InvalidArgument desc = unsupported statement for batching: select 1"; g != w { + t.Fatalf("error message mismatch\n Got: %v\nWant: %v", g, w) + } + + if err := CloseConnection(ctx, poolId, connId); err != nil { + t.Fatalf("CloseConnection returned unexpected error: %v", err) + } + if err := ClosePool(ctx, poolId); err != nil { + t.Fatalf("ClosePool returned unexpected error: %v", err) + } +} diff --git a/spannerlib/api/connection.go b/spannerlib/api/connection.go index 22f911cd..2f0b611b 100644 --- a/spannerlib/api/connection.go +++ b/spannerlib/api/connection.go @@ -445,6 +445,9 @@ func extractParams(directExecuteContext context.Context, statement *spannerpb.Ex ReturnResultSetStats: true, DirectExecuteQuery: true, DirectExecuteContext: directExecuteContext, + QueryOptions: spanner.QueryOptions{ + Mode: &statement.QueryMode, + }, }) if statement.Params != nil { if statement.ParamTypes == nil { @@ -480,7 +483,7 @@ func determineBatchType(conn *Connection, statements []*spannerpb.ExecuteBatchDm } else if firstStatementType == parser.StatementTypeDdl { batchType = parser.BatchTypeDdl } else { - return status.Errorf(codes.InvalidArgument, "unsupported statement type for batching: %v", firstStatementType) + return status.Errorf(codes.InvalidArgument, "unsupported statement for batching: %v", statements[0].Sql) } for i, statement := range statements { if i > 0 { diff --git a/spannerlib/grpc-server/server.go b/spannerlib/grpc-server/server.go index 206237c4..713e5ab6 100644 --- a/spannerlib/grpc-server/server.go +++ b/spannerlib/grpc-server/server.go @@ -46,6 +46,18 @@ func main() { if err != nil { log.Fatalf("failed to listen: %v\n", err) } + grpcServer, err := createServer() + if err != nil { + log.Fatalf("failed to create server: %v\n", err) + } + log.Printf("Starting gRPC server on %s\n", lis.Addr().String()) + err = grpcServer.Serve(lis) + if err != nil { + log.Printf("failed to serve: %v\n", err) + } +} + +func createServer() (*grpc.Server, error) { var opts []grpc.ServerOption // Set a max message size that is essentially no limit. opts = append(opts, grpc.MaxRecvMsgSize(math.MaxInt32)) @@ -53,11 +65,8 @@ func main() { server := spannerLibServer{} pb.RegisterSpannerLibServer(grpcServer, &server) - log.Printf("Starting gRPC server on %s\n", lis.Addr().String()) - err = grpcServer.Serve(lis) - if err != nil { - log.Printf("failed to serve: %v\n", err) - } + + return grpcServer, nil } var _ pb.SpannerLibServer = &spannerLibServer{} diff --git a/spannerlib/grpc-server/server_test.go b/spannerlib/grpc-server/server_test.go index 8c6879bb..6701ef25 100644 --- a/spannerlib/grpc-server/server_test.go +++ b/spannerlib/grpc-server/server_test.go @@ -2,6 +2,8 @@ package main import ( "context" + cryptorand "crypto/rand" + "encoding/base64" "errors" "fmt" "io" @@ -236,6 +238,74 @@ func TestExecuteStreaming(t *testing.T) { } } +func TestLargeMessage(t *testing.T) { + t.Parallel() + ctx := context.Background() + + server, teardown := setupMockSpannerServer(t) + defer teardown() + dsn := fmt.Sprintf("%s/projects/p/instances/i/databases/d?useplaintext=true", server.Address) + + client, cleanup := startTestSpannerLibServer(t) + defer cleanup() + + pool, err := client.CreatePool(ctx, &pb.CreatePoolRequest{ConnectionString: dsn}) + if err != nil { + t.Fatalf("failed to create pool: %v", err) + } + connection, err := client.CreateConnection(ctx, &pb.CreateConnectionRequest{Pool: pool}) + if err != nil { + t.Fatalf("failed to create connection: %v", err) + } + + query := "insert into foo (value) values (@value)" + b := make([]byte, 10_000_000) + n, err := cryptorand.Read(b) + if err != nil { + t.Fatalf("failed to read value: %v", err) + } + if g, w := n, len(b); g != w { + t.Fatalf("length mismatch\n Got: %v\nWant: %v", g, w) + } + value := base64.StdEncoding.EncodeToString(b) + _ = server.TestSpanner.PutStatementResult(query, &testutil.StatementResult{ + Type: testutil.StatementResultUpdateCount, + UpdateCount: 1, + }) + stream, err := client.ExecuteStreaming(ctx, &pb.ExecuteRequest{ + Connection: connection, + ExecuteSqlRequest: &sppb.ExecuteSqlRequest{ + Sql: query, + Params: &structpb.Struct{ + Fields: map[string]*structpb.Value{ + "value": {Kind: &structpb.Value_StringValue{StringValue: value}}, + }, + }, + }, + }) + if err != nil { + t.Fatalf("failed to execute: %v", err) + } + numRows := 0 + for { + row, err := stream.Recv() + if err != nil { + t.Fatalf("failed to receive row: %v", err) + } + if len(row.Data) == 0 { + break + } + numRows++ + } + if g, w := numRows, 0; g != w { + t.Fatalf("num rows mismatch\n Got: %v\nWant: %v", g, w) + } + + if _, err := client.ClosePool(ctx, pool); err != nil { + t.Fatalf("failed to close pool: %v", err) + } +} + func TestExecuteStreamingWithTimeout(t *testing.T) { t.Parallel() ctx := context.Background() @@ -660,11 +730,10 @@ func startTestSpannerLibServer(t *testing.T) (client pb.SpannerLibClient, cleanu t.Fatalf("failed to listen: %v\n", err) } addr := lis.Addr().String() - var opts []grpc.ServerOption - grpcServer := grpc.NewServer(opts...) - - server := spannerLibServer{} - pb.RegisterSpannerLibServer(grpcServer, &server) + grpcServer, err := createServer() + if err != nil { + t.Fatalf("failed to create server: %v\n", err) + } go func() { _ = grpcServer.Serve(lis) }() conn, err := grpc.NewClient( diff --git a/spannerlib/wrappers/spannerlib-dotnet/build.sh b/spannerlib/wrappers/spannerlib-dotnet/build.sh index c763088e..daca91b6 100755 --- a/spannerlib/wrappers/spannerlib-dotnet/build.sh +++ b/spannerlib/wrappers/spannerlib-dotnet/build.sh @@ -38,8 +38,10 @@ echo "Skip windows: $SKIP_WINDOWS" # Remove existing builds rm -r ./spannerlib-dotnet-native/libraries 2> /dev/null +rm -r ./spannerlib-dotnet-grpc-server/binaries 2> /dev/null rm -r ./*/bin 2> /dev/null rm -r ./*/obj 2> /dev/null +dotnet nuget locals global-packages --clear # Build gRPC server echo "Building gRPC server..." diff --git a/spannerlib/wrappers/spannerlib-dotnet/publish.sh b/spannerlib/wrappers/spannerlib-dotnet/publish.sh index 4c703e2f..44b73e23 100755 --- a/spannerlib/wrappers/spannerlib-dotnet/publish.sh +++ b/spannerlib/wrappers/spannerlib-dotnet/publish.sh @@ -20,6 +20,9 @@ fi # Build all components ./build.sh +# Remove existing packages. +find ./**/bin/Release -type f -name "Alpha*.nupkg" -exec rm {} \; + # Pack and publish to nuget dotnet pack find ./**/bin/Release -type f -name "Alpha*.nupkg" -exec \ diff --git a/spannerlib/wrappers/spannerlib-dotnet/spannerlib-dotnet-mockserver/MockSpannerServer.cs b/spannerlib/wrappers/spannerlib-dotnet/spannerlib-dotnet-mockserver/MockSpannerServer.cs index 129449af..8b5a9010 100644 --- a/spannerlib/wrappers/spannerlib-dotnet/spannerlib-dotnet-mockserver/MockSpannerServer.cs +++ b/spannerlib/wrappers/spannerlib-dotnet/spannerlib-dotnet-mockserver/MockSpannerServer.cs @@ -14,6 +14,8 @@ using System.Collections; using System.Collections.Concurrent; +using System.Diagnostics; +using System.Reflection; using Google.Cloud.Spanner.Admin.Database.V1; using Google.Cloud.Spanner.Common.V1; using Google.Cloud.Spanner.V1; @@ -22,6 +24,7 @@ using Google.Protobuf.WellKnownTypes; using Google.Rpc; using Grpc.Core; +using Enum = System.Enum; using Status = Google.Rpc.Status; using GrpcCore = Grpc.Core; @@ -356,7 +359,7 @@ public void AddOrUpdateExecutionTime(string method, ExecutionTime executionTime) ); } - private void AddDialectResult() + public void AddDialectResult(DatabaseDialect dialect = DatabaseDialect.GoogleStandardSql) { AddOrUpdateStatementResult(SDialectQuery, StatementResult.CreateResultSet( @@ -366,10 +369,23 @@ private void AddDialectResult() }, new List { - new object[] { "GOOGLE_STANDARD_SQL" }, + new object[] { GetEnumOriginalName(dialect) }, })); } - + + string GetEnumOriginalName(Enum enumValue) + { + var enumType = enumValue.GetType(); + var enumValueName = enumValue.ToString(); + var enumValueInfo = enumType.GetMember(enumValueName).Single(); + var attribute = enumValueInfo.GetCustomAttribute(); + if (attribute is null) + { + throw new InvalidOperationException($"Attribute '{nameof(Protobuf.Reflection.OriginalNameAttribute)}' not found on enum '{enumType}'."); + } + return attribute.Name; + } + internal void AbortTransaction(string transactionId) { _abortedTransactions.TryAdd(ByteString.FromBase64(transactionId), true); @@ -383,8 +399,31 @@ internal void AbortNextStatement() } } + public void ClearRequests() + { + _requests.Clear(); + } + public IEnumerable Requests => new List(_requests).AsReadOnly(); + public bool WaitForRequestsToContain(Func predicate) + { + return WaitForRequestsToContain(predicate, new TimeSpan(5 * TimeSpan.TicksPerSecond)); + } + + public bool WaitForRequestsToContain(Func predicate, TimeSpan timeout) + { + var stopwatch = new Stopwatch(); + while (stopwatch.Elapsed < timeout) + { + if (Requests.Any(predicate)) + { + return true; + } + } + return false; + } + public IEnumerable Contexts => new List(_contexts).AsReadOnly(); public IEnumerable Headers => new List(_headers).AsReadOnly(); @@ -713,7 +752,7 @@ public override async Task ExecuteStreamingSql(ExecuteSqlRequest request, IServe } else { - throw new RpcException(new GrpcCore.Status(StatusCode.InvalidArgument, $"No result found for {request.Sql}")); + throw new RpcException(new GrpcCore.Status(StatusCode.InvalidArgument, $"No result found for {request.Sql[0..Math.Min(request.Sql.Length, 5000)]}")); } } diff --git a/spannerlib/wrappers/spannerlib-dotnet/spannerlib-dotnet-mockserver/SpannerConverter.cs b/spannerlib/wrappers/spannerlib-dotnet/spannerlib-dotnet-mockserver/SpannerConverter.cs index 9b64ffab..9f1a31c7 100644 --- a/spannerlib/wrappers/spannerlib-dotnet/spannerlib-dotnet-mockserver/SpannerConverter.cs +++ b/spannerlib/wrappers/spannerlib-dotnet/spannerlib-dotnet-mockserver/SpannerConverter.cs @@ -78,11 +78,24 @@ internal static Value ToProtobufValue(Spanner.V1.Type type, object? value) StringValue = XmlConvert.ToString(Convert.ToDateTime(value, InvariantCulture), XmlDateTimeSerializationMode.Utc) }; case TypeCode.Date: + if (value is DateOnly dateOnly) + { + return new Value + { + StringValue = dateOnly.ToString("yyyy-MM-dd"), + }; + } return new Value { StringValue = StripTimePart( XmlConvert.ToString(Convert.ToDateTime(value, InvariantCulture), XmlDateTimeSerializationMode.Utc)) }; + case TypeCode.Interval: + if (value is TimeSpan timeSpan) + { + return Value.ForString(XmlConvert.ToString(timeSpan)); + } + return Value.ForString(value.ToString()); case TypeCode.Json: if (value is string stringValue) { diff --git a/spannerlib/wrappers/spannerlib-dotnet/spannerlib-dotnet-mockserver/SpannerMockServerFixture.cs b/spannerlib/wrappers/spannerlib-dotnet/spannerlib-dotnet-mockserver/SpannerMockServerFixture.cs index c9d21e6a..aa25eac4 100644 --- a/spannerlib/wrappers/spannerlib-dotnet/spannerlib-dotnet-mockserver/SpannerMockServerFixture.cs +++ b/spannerlib/wrappers/spannerlib-dotnet/spannerlib-dotnet-mockserver/SpannerMockServerFixture.cs @@ -18,6 +18,7 @@ using Microsoft.AspNetCore.Hosting.Server.Features; using Microsoft.AspNetCore.Server.Kestrel.Core; using Microsoft.Extensions.Configuration; +using Microsoft.Extensions.DependencyInjection; namespace Google.Cloud.SpannerLib.MockServer; @@ -53,10 +54,18 @@ public SpannerMockServerFixture() var builder = WebHost.CreateDefaultBuilder(); builder.ConfigureAppConfiguration(configurationBuilder => configurationBuilder.AddJsonFile("appsettings.json", true)); builder.UseStartup(_ => new MockServerStartup(SpannerMock, DatabaseAdminMock)); + builder.ConfigureServices(services => + { + services.AddGrpc(options => + { + options.MaxReceiveMessageSize = null; + }); + }); builder.ConfigureKestrel(options => { // Setup a HTTP/2 endpoint without TLS. options.Listen(endpoint, o => o.Protocols = HttpProtocols.Http2); + options.Limits.MaxRequestBodySize = 100_000_000; }); _host = builder.Build(); _host.Start(); diff --git a/testutil/mocked_inmem_server.go b/testutil/mocked_inmem_server.go index 5bf012ff..1bb70452 100644 --- a/testutil/mocked_inmem_server.go +++ b/testutil/mocked_inmem_server.go @@ -17,6 +17,7 @@ package testutil import ( "encoding/base64" "fmt" + "math" "net" "strconv" "testing" @@ -111,7 +112,10 @@ func (s *MockedSpannerInMemTestServer) setupMockedServerWithAddr(t *testing.T, a s.setupSelect1Result() s.setupFooResults() s.setupSingersResults() - s.server = grpc.NewServer() + var serverOpts []grpc.ServerOption + // Set a max message size that is essentially no limit. + serverOpts = append(serverOpts, grpc.MaxRecvMsgSize(math.MaxInt32)) + s.server = grpc.NewServer(serverOpts...) spannerpb.RegisterSpannerServer(s.server, s.TestSpanner) instancepb.RegisterInstanceAdminServer(s.server, s.TestInstanceAdmin) databasepb.RegisterDatabaseAdminServer(s.server, s.TestDatabaseAdmin)