Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions tests/wellknown/proto/wellknown.proto
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,13 @@ service Admin {
rpc EmptyCall(google.protobuf.Empty) returns (google.protobuf.Empty);
rpc StringCall(google.protobuf.StringValue) returns (google.protobuf.Empty);
rpc AnyCall(google.protobuf.Any) returns (google.protobuf.Empty);
// Wrapper types that map to primitives
rpc BoolCall(google.protobuf.BoolValue) returns (google.protobuf.BoolValue);
rpc Int32Call(google.protobuf.Int32Value) returns (google.protobuf.Int32Value);
rpc Int64Call(google.protobuf.Int64Value) returns (google.protobuf.Int64Value);
rpc UInt32Call(google.protobuf.UInt32Value) returns (google.protobuf.UInt32Value);
rpc UInt64Call(google.protobuf.UInt64Value) returns (google.protobuf.UInt64Value);
rpc FloatCall(google.protobuf.FloatValue) returns (google.protobuf.FloatValue);
rpc DoubleCall(google.protobuf.DoubleValue) returns (google.protobuf.DoubleValue);
rpc BytesCall(google.protobuf.BytesValue) returns (google.protobuf.BytesValue);
}
90 changes: 33 additions & 57 deletions tonic-prost-build/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@

use proc_macro2::TokenStream;
use prost_build::{Method, Service};
use quote::{quote, ToTokens};
use quote::ToTokens;
use std::{
collections::HashSet,
ffi::OsString,
Expand Down Expand Up @@ -192,67 +192,43 @@ impl tonic_build::Method for TonicBuildMethod {
proto_path: &str,
compile_well_known_types: bool,
) -> (TokenStream, TokenStream) {
let request = if is_google_type(&self.prost_method.input_type) && !compile_well_known_types
{
// For well-known types, map to absolute paths that will work with super::
match self.prost_method.input_type.as_str() {
".google.protobuf.Empty" => quote!(()),
".google.protobuf.Any" => quote!(::prost_types::Any),
".google.protobuf.StringValue" => quote!(::prost::alloc::string::String),
_ => {
// For other google types, assume they're in prost_types
let type_name = self
.prost_method
.input_type
.trim_start_matches(".google.protobuf.")
.to_string();
syn::parse_str::<syn::Path>(&format!("::prost_types::{type_name}"))
.unwrap()
.to_token_stream()
}
}
} else if NON_PATH_TYPE_ALLOWLIST
.iter()
.any(|ty| self.prost_method.input_type.ends_with(ty))
{
self.prost_method.input_type.parse::<TokenStream>().unwrap()
} else {
// Check if this is an extern type that starts with :: or crate::
if self.prost_method.input_type.starts_with("::")
|| self.prost_method.input_type.starts_with("crate::")
// Use input_proto_type to detect google types, since input_type is already
// resolved by prost-build (e.g., ".google.protobuf.BoolValue" -> "bool")
let request =
if is_google_type(&self.prost_method.input_proto_type) && !compile_well_known_types {
// prost-build already resolved the type, use it directly
self.prost_method.input_type.parse::<TokenStream>().unwrap()
} else if NON_PATH_TYPE_ALLOWLIST
.iter()
.any(|ty| self.prost_method.input_type.ends_with(ty))
{
// This is an extern type, use it directly
self.prost_method.input_type.parse::<TokenStream>().unwrap()
} else {
// Replace dots with double colons for the type name
let rust_type = self.prost_method.input_type.replace('.', "::");
// Remove leading :: if present
let rust_type = rust_type.trim_start_matches("::");
syn::parse_str::<syn::Path>(&format!("{proto_path}::{rust_type}"))
.unwrap()
.to_token_stream()
}
};
// Check if this is an extern type that starts with :: or crate::
if self.prost_method.input_type.starts_with("::")
|| self.prost_method.input_type.starts_with("crate::")
{
// This is an extern type, use it directly
self.prost_method.input_type.parse::<TokenStream>().unwrap()
} else {
// Replace dots with double colons for the type name
let rust_type = self.prost_method.input_type.replace('.', "::");
// Remove leading :: if present
let rust_type = rust_type.trim_start_matches("::");
syn::parse_str::<syn::Path>(&format!("{proto_path}::{rust_type}"))
.unwrap()
.to_token_stream()
}
};

// Use output_proto_type to detect google types
let response =
if is_google_type(&self.prost_method.output_type) && !compile_well_known_types {
// For well-known types, map to absolute paths that will work with super::
match self.prost_method.output_type.as_str() {
".google.protobuf.Empty" => quote!(()),
".google.protobuf.Any" => quote!(::prost_types::Any),
".google.protobuf.StringValue" => quote!(::prost::alloc::string::String),
_ => {
// For other google types, assume they're in prost_types
let type_name = self
.prost_method
.output_type
.trim_start_matches(".google.protobuf.")
.to_string();
syn::parse_str::<syn::Path>(&format!("::prost_types::{type_name}"))
.unwrap()
.to_token_stream()
}
}
if is_google_type(&self.prost_method.output_proto_type) && !compile_well_known_types {
// prost-build already resolved the type, use it directly
self.prost_method
.output_type
.parse::<TokenStream>()
.unwrap()
} else if NON_PATH_TYPE_ALLOWLIST
.iter()
.any(|ty| self.prost_method.output_type.ends_with(ty))
Expand Down
140 changes: 111 additions & 29 deletions tonic-prost-build/src/tests.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,16 @@
use super::*;
use prost_build::{Comments, Method};
use quote::quote;

fn create_test_method(input_type: String, output_type: String) -> TonicBuildMethod {
/// Create a test method with separate proto types and resolved rust types.
/// This reflects how prost-build actually works: input_proto_type is the original
/// protobuf type (e.g., ".google.protobuf.BoolValue") while input_type is the
/// resolved Rust type (e.g., "bool").
fn create_test_method_with_proto_types(
input_type: String,
output_type: String,
input_proto_type: String,
output_proto_type: String,
) -> TonicBuildMethod {
TonicBuildMethod {
prost_method: Method {
name: "TestMethod".to_string(),
Expand All @@ -12,10 +20,10 @@ fn create_test_method(input_type: String, output_type: String) -> TonicBuildMeth
trailing: vec![],
leading_detached: vec![],
},
input_type: input_type.clone(),
output_type: output_type.clone(),
input_proto_type: input_type,
output_proto_type: output_type,
input_type,
output_type,
input_proto_type,
output_proto_type,
client_streaming: false,
server_streaming: false,
options: prost_types::MethodOptions::default(),
Expand All @@ -24,53 +32,111 @@ fn create_test_method(input_type: String, output_type: String) -> TonicBuildMeth
}
}

/// Legacy helper for non-google types where proto type == rust type
fn create_test_method(input_type: String, output_type: String) -> TonicBuildMethod {
create_test_method_with_proto_types(
input_type.clone(),
output_type.clone(),
input_type,
output_type,
)
}

#[test]
fn test_request_response_name_google_types_not_compiled() {
// Test Google well-known types when compile_well_known_types is false
let test_cases = vec![
(".google.protobuf.Empty", quote!(())),
(".google.protobuf.Any", quote!(::prost_types::Any)),
// Test Google well-known types when compile_well_known_types is false.
// Reflect how prost-build resolves types:
// - proto_type is the original protobuf type (e.g., ".google.protobuf.BoolValue")
// - rust_type is what prost-build resolves it to (e.g., "bool")
let test_cases: Vec<(&str, &str, &str)> = vec![
// (proto_type, rust_type from prost-build, expected output)
(".google.protobuf.Empty", "()", "()"),
(
".google.protobuf.Any",
"::prost_types::Any",
":: prost_types :: Any",
),
(
".google.protobuf.StringValue",
quote!(::prost::alloc::string::String),
"::prost::alloc::string::String",
":: prost :: alloc :: string :: String",
),
(
".google.protobuf.Timestamp",
quote!(::prost_types::Timestamp),
"::prost_types::Timestamp",
":: prost_types :: Timestamp",
),
(
".google.protobuf.Duration",
"::prost_types::Duration",
":: prost_types :: Duration",
),
(
".google.protobuf.Value",
"::prost_types::Value",
":: prost_types :: Value",
),
// Wrapper types that map to primitives (the bug fix!)
(".google.protobuf.BoolValue", "bool", "bool"),
(".google.protobuf.Int32Value", "i32", "i32"),
(".google.protobuf.Int64Value", "i64", "i64"),
(".google.protobuf.UInt32Value", "u32", "u32"),
(".google.protobuf.UInt64Value", "u64", "u64"),
(".google.protobuf.FloatValue", "f32", "f32"),
(".google.protobuf.DoubleValue", "f64", "f64"),
(
".google.protobuf.BytesValue",
"::prost::alloc::vec::Vec<u8>",
":: prost :: alloc :: vec :: Vec < u8 >",
),
(".google.protobuf.Duration", quote!(::prost_types::Duration)),
(".google.protobuf.Value", quote!(::prost_types::Value)),
];

for (type_name, expected) in test_cases {
let method = create_test_method(type_name.to_string(), type_name.to_string());
for (proto_type, rust_type, expected) in test_cases {
let method = create_test_method_with_proto_types(
rust_type.to_string(),
rust_type.to_string(),
proto_type.to_string(),
proto_type.to_string(),
);
let (request, response) = method.request_response_name("super", false);

assert_eq!(
request.to_string(),
expected.to_string(),
"Failed for input type: {type_name}"
expected,
"Failed for input proto_type: {proto_type}, rust_type: {rust_type}"
);
assert_eq!(
response.to_string(),
expected.to_string(),
"Failed for output type: {type_name}"
expected,
"Failed for output proto_type: {proto_type}, rust_type: {rust_type}"
);
}
}

#[test]
fn test_request_response_name_google_types_compiled() {
// Test Google well-known types when compile_well_known_types is true
// Test Google well-known types when compile_well_known_types is true.
// When compile_well_known_types is true, prost-build doesn't resolve
// google types to external paths, so input_type == input_proto_type
// without the leading dot and with proper Rust path format
let test_cases = vec![
".google.protobuf.Empty",
".google.protobuf.Any",
".google.protobuf.StringValue",
".google.protobuf.Timestamp",
".google.protobuf.BoolValue",
];

for type_name in test_cases {
let method = create_test_method(type_name.to_string(), type_name.to_string());
// When compile_well_known_types is true, input_type is a path like
// "google.protobuf.Empty" (not resolved to "()" or "bool")
let rust_type = type_name.trim_start_matches('.');
let method = create_test_method_with_proto_types(
rust_type.to_string(),
rust_type.to_string(),
type_name.to_string(),
type_name.to_string(),
);
let (request, response) = method.request_response_name("super", true);

// When compile_well_known_types is true, it should use the normal path logic
Expand Down Expand Up @@ -212,25 +278,41 @@ fn test_request_response_name_different_proto_paths() {

#[test]
fn test_request_response_name_mixed_types() {
// Test with different request and response types
let method = create_test_method(
".google.protobuf.Empty".to_string(),
"mypackage.MyResponse".to_string(),
// Test with google type as request and regular type as response
let method = create_test_method_with_proto_types(
"()".to_string(), // rust type for Empty
"mypackage.MyResponse".to_string(), // rust type for regular message
".google.protobuf.Empty".to_string(), // proto type
".mypackage.MyResponse".to_string(), // proto type
);
let (request, response) = method.request_response_name("super", false);

assert_eq!(request.to_string(), "()");
assert_eq!(response.to_string(), "super :: mypackage :: MyResponse");

// Test with extern type as request and google type as response
let method = create_test_method(
"::external::Request".to_string(),
".google.protobuf.Any".to_string(),
let method = create_test_method_with_proto_types(
"::external::Request".to_string(), // rust type (extern path)
"::prost_types::Any".to_string(), // rust type for Any
".external.Request".to_string(), // proto type
".google.protobuf.Any".to_string(), // proto type
);
let (request, response) = method.request_response_name("super", false);

assert_eq!(request.to_string(), ":: external :: Request");
assert_eq!(response.to_string(), ":: prost_types :: Any");

// Test with BoolValue (primitive wrapper) as response
let method = create_test_method_with_proto_types(
"mypackage.MyRequest".to_string(), // rust type
"bool".to_string(), // rust type for BoolValue
".mypackage.MyRequest".to_string(), // proto type
".google.protobuf.BoolValue".to_string(), // proto type
);
let (request, response) = method.request_response_name("super", false);

assert_eq!(request.to_string(), "super :: mypackage :: MyRequest");
assert_eq!(response.to_string(), "bool");
}

#[test]
Expand Down
Loading