Skip to content

Commit a7a4aac

Browse files
committed
Add grpcio support.
1 parent 124613f commit a7a4aac

File tree

13 files changed

+765
-122
lines changed

13 files changed

+765
-122
lines changed

README.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -533,6 +533,12 @@ protoc \
533533
/usr/local/include/google/protobuf/*.proto
534534
```
535535

536+
### Using grpcio library instead of grpclib
537+
538+
In order to use the `grpcio` library instead of `grpclib`, you can use the `--custom_opt=grpcio`
539+
option when running the `protoc` command.
540+
This will generate stubs compatible with the `grpcio` library.
541+
536542
### TODO
537543

538544
- [x] Fixed length fields

poetry.lock

Lines changed: 94 additions & 78 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ dynamic = ["dependencies"]
1919
# The Ruff version is pinned. To update it, also update it in .pre-commit-config.yaml
2020
ruff = { version = "~0.9.1", optional = true }
2121
grpclib = "^0.4.1"
22+
grpcio = { version = ">=1.73.0", optional = true }
2223
jinja2 = { version = ">=3.0.3", optional = true }
2324
python-dateutil = "^2.8"
2425
typing-extensions = "^4.7.1"
@@ -45,13 +46,15 @@ pydantic = ">=2.0,<3"
4546
protobuf = "^5"
4647
cachelib = "^0.13.0"
4748
tomlkit = ">=0.7.0"
49+
grpcio-testing = "^1.54.2"
4850

4951
[project.scripts]
5052
protoc-gen-python_betterproto = "betterproto.plugin:main"
5153

5254
[project.optional-dependencies]
5355
compiler = ["ruff", "jinja2"]
5456
rust-codec = ["betterproto-rust-codec"]
57+
grpcio = ["grpcio"]
5558

5659
[tool.ruff]
5760
extend-exclude = ["tests/output_*"]
Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,130 @@
1+
from abc import ABC
2+
from typing import (
3+
TYPE_CHECKING,
4+
AsyncIterable,
5+
AsyncIterator,
6+
Iterable,
7+
Mapping,
8+
Optional,
9+
Union,
10+
)
11+
12+
import grpc
13+
14+
15+
if TYPE_CHECKING:
16+
from .._types import (
17+
T,
18+
IProtoMessage,
19+
)
20+
21+
22+
Value = Union[str, bytes]
23+
MetadataLike = Union[Mapping[str, Value], Iterable[tuple[str, Value]]]
24+
MessageSource = Union[Iterable["IProtoMessage"], AsyncIterable["IProtoMessage"]]
25+
26+
27+
class ServiceStub(ABC):
28+
29+
def __init__(
30+
self,
31+
channel: grpc.aio.Channel,
32+
*,
33+
timeout: Optional[float] = None,
34+
metadata: Optional[MetadataLike] = None,
35+
) -> None:
36+
self.channel = channel
37+
self.timeout = timeout
38+
self.metadata = metadata
39+
40+
def _resolve_request_kwargs(
41+
self,
42+
timeout: Optional[float],
43+
metadata: Optional[MetadataLike],
44+
):
45+
# Avoid creating dict if no overrides needed
46+
if timeout is None and metadata is None:
47+
# Return cached default kwargs if both timeout and metadata are None
48+
if not hasattr(self, '_default_kwargs'):
49+
self._default_kwargs = {
50+
"timeout": self.timeout,
51+
"metadata": self.metadata,
52+
}
53+
return self._default_kwargs
54+
55+
return {
56+
"timeout": self.timeout if timeout is None else timeout,
57+
"metadata": self.metadata if metadata is None else metadata,
58+
}
59+
60+
async def _unary_unary(
61+
self,
62+
stub_method: grpc.aio.UnaryUnaryMultiCallable,
63+
request: "IProtoMessage",
64+
*,
65+
timeout: Optional[float] = None,
66+
metadata: Optional[MetadataLike] = None,
67+
) -> "T":
68+
return await stub_method(
69+
request,
70+
**self._resolve_request_kwargs(timeout, metadata),
71+
)
72+
73+
async def _unary_stream(
74+
self,
75+
stub_method: grpc.aio.UnaryStreamMultiCallable,
76+
request: "IProtoMessage",
77+
*,
78+
timeout: Optional[float] = None,
79+
metadata: Optional[MetadataLike] = None,
80+
) -> AsyncIterator["T"]:
81+
call = stub_method(
82+
request,
83+
**self._resolve_request_kwargs(timeout, metadata),
84+
)
85+
async for response in call:
86+
yield response
87+
88+
async def _stream_unary(
89+
self,
90+
stub_method: grpc.aio.StreamUnaryMultiCallable,
91+
request_iterator: MessageSource,
92+
*,
93+
timeout: Optional[float] = None,
94+
metadata: Optional[MetadataLike] = None,
95+
) -> "T":
96+
call = stub_method(
97+
self._wrap_message_iterator(request_iterator),
98+
**self._resolve_request_kwargs(timeout, metadata),
99+
)
100+
return await call
101+
102+
async def _stream_stream(
103+
self,
104+
stub_method: grpc.aio.StreamStreamMultiCallable,
105+
request_iterator: MessageSource,
106+
*,
107+
timeout: Optional[float] = None,
108+
metadata: Optional[MetadataLike] = None,
109+
) -> AsyncIterator["T"]:
110+
call = stub_method(
111+
self._wrap_message_iterator(request_iterator),
112+
**self._resolve_request_kwargs(timeout, metadata),
113+
)
114+
async for response in call:
115+
yield response
116+
117+
@staticmethod
118+
def _wrap_message_iterator(
119+
messages: MessageSource,
120+
) -> AsyncIterator["IProtoMessage"]:
121+
if hasattr(messages, '__aiter__'):
122+
async def async_wrapper():
123+
async for message in messages:
124+
yield message
125+
return async_wrapper()
126+
else:
127+
async def sync_wrapper():
128+
for message in messages:
129+
yield message
130+
return sync_wrapper()
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
from abc import ABC, abstractmethod
2+
from typing import TYPE_CHECKING, Dict
3+
4+
5+
if TYPE_CHECKING:
6+
import grpc
7+
8+
9+
class ServiceBase(ABC):
10+
11+
@property
12+
@abstractmethod
13+
def __rpc_methods__(self) -> Dict[str, "grpc.RpcMethodHandler"]: ...
14+
15+
@property
16+
@abstractmethod
17+
def __proto_path__(self) -> str: ...
18+
19+
20+
def register_servicers(server: "grpc.aio.Server", *servicers: ServiceBase):
21+
from grpc import method_handlers_generic_handler
22+
23+
server.add_generic_rpc_handlers(
24+
tuple(
25+
method_handlers_generic_handler(
26+
servicer.__proto_path__, servicer.__rpc_methods__
27+
)
28+
for servicer in servicers
29+
)
30+
)

src/betterproto/plugin/models.py

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -270,6 +270,7 @@ class OutputTemplate:
270270
imports_type_checking_only: Set[str] = field(default_factory=set)
271271
pydantic_dataclasses: bool = False
272272
output: bool = True
273+
use_grpcio: bool = False
273274
typing_compiler: TypingCompiler = field(default_factory=DirectImportTypingCompiler)
274275

275276
@property
@@ -697,18 +698,24 @@ class ServiceMethodCompiler(ProtoContentBase):
697698
proto_obj: MethodDescriptorProto
698699
path: List[int] = PLACEHOLDER
699700
comment_indent: int = 8
701+
use_grpcio: bool = False
700702

701703
def __post_init__(self) -> None:
702704
# Add method to service
703705
self.parent.methods.append(self)
704-
705-
self.output_file.imports_type_checking_only.add("import grpclib.server")
706-
self.output_file.imports_type_checking_only.add(
707-
"from betterproto.grpc.grpclib_client import MetadataLike"
708-
)
709-
self.output_file.imports_type_checking_only.add(
710-
"from grpclib.metadata import Deadline"
711-
)
706+
if self.use_grpcio:
707+
self.output_file.imports_type_checking_only.add("import grpc.aio")
708+
self.output_file.imports_type_checking_only.add(
709+
"from betterproto.grpc.grpcio_client import MetadataLike"
710+
)
711+
else:
712+
self.output_file.imports_type_checking_only.add("import grpclib.server")
713+
self.output_file.imports_type_checking_only.add(
714+
"from betterproto.grpc.grpclib_client import MetadataLike"
715+
)
716+
self.output_file.imports_type_checking_only.add(
717+
"from grpclib.metadata import Deadline"
718+
)
712719

713720
super().__post_init__() # check for unset fields
714721

src/betterproto/plugin/parser.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,7 @@ def generate_code(request: CodeGeneratorRequest) -> CodeGeneratorResponse:
8080
response.supported_features = CodeGeneratorResponseFeature.FEATURE_PROTO3_OPTIONAL
8181

8282
request_data = PluginRequestCompiler(plugin_request_obj=request)
83+
use_grpcio = "USE_GRPCIO" in plugin_options
8384
# Gather output packages
8485
for proto_file in request.proto_file:
8586
output_package_name = proto_file.package
@@ -90,7 +91,8 @@ def generate_code(request: CodeGeneratorRequest) -> CodeGeneratorResponse:
9091
)
9192
# Add this input file to the output corresponding to this package
9293
request_data.output_packages[output_package_name].input_files.append(proto_file)
93-
94+
if use_grpcio:
95+
request_data.output_packages[output_package_name].use_grpcio = True
9496
if (
9597
proto_file.package == "google.protobuf"
9698
and "INCLUDE_GOOGLE" not in plugin_options
@@ -143,7 +145,9 @@ def generate_code(request: CodeGeneratorRequest) -> CodeGeneratorResponse:
143145
for output_package_name, output_package in request_data.output_packages.items():
144146
for proto_input_file in output_package.input_files:
145147
for index, service in enumerate(proto_input_file.service):
146-
read_protobuf_service(proto_input_file, service, index, output_package)
148+
read_protobuf_service(
149+
proto_input_file, service, index, output_package, use_grpcio
150+
)
147151

148152
# Generate output files
149153
output_paths: Set[pathlib.Path] = set()
@@ -253,6 +257,7 @@ def read_protobuf_service(
253257
service: ServiceDescriptorProto,
254258
index: int,
255259
output_package: OutputTemplate,
260+
use_grpcio: bool = False,
256261
) -> None:
257262
service_data = ServiceCompiler(
258263
source_file=source_file,
@@ -266,4 +271,5 @@ def read_protobuf_service(
266271
parent=service_data,
267272
proto_obj=method,
268273
path=[6, index, 2, j],
274+
use_grpcio=use_grpcio,
269275
)

src/betterproto/templates/header.py.j2

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,8 +42,16 @@ from pydantic import {% for i in output_file.pydantic_imports|sort %}{{ i }}{% i
4242

4343
{% endif %}
4444

45+
46+
{% if output_file.use_grpcio %}
47+
import grpc
48+
from betterproto.grpc.grpcio_client import ServiceStub
49+
from betterproto.grpc.grpcio_server import ServiceBase
50+
{% endif %}
51+
4552
import betterproto
46-
{% if output_file.services %}
53+
{% if not output_file.use_grpcio %}
54+
from betterproto.grpc.grpclib_client import ServiceStub
4755
from betterproto.grpc.grpclib_server import ServiceBase
4856
import grpclib
4957
{% endif %}

0 commit comments

Comments
 (0)