Skip to content

Commit bc6ae4d

Browse files
committed
Add grpcio support.
1 parent a3f7fd0 commit bc6ae4d

File tree

15 files changed

+783
-124
lines changed

15 files changed

+783
-124
lines changed

README.md

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -536,6 +536,16 @@ protoc \
536536
/usr/local/include/google/protobuf/*.proto
537537
```
538538

539+
### Using grpcio library instead of grpclib
540+
541+
In order to use the `grpcio` library instead of `grpclib`, you can use the `--python_betterproto_opt=USE_GRPCIO`
542+
option when running the `protoc` command.
543+
This will generate stubs compatible with the `grpcio` library.
544+
545+
Example:
546+
```sh
547+
protoc -I . --python_betterproto_out=. --python_betterproto_opt=USE_GRPCIO demo.proto
548+
```
539549
### TODO
540550

541551
- [x] Fixed length fields

poetry.lock

Lines changed: 94 additions & 78 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ dynamic = ["dependencies"]
1919
# The Ruff version is pinned. To update it, also update it in .pre-commit-config.yaml
2020
ruff = { version = "~0.9.1", optional = true }
2121
grpclib = "^0.4.1"
22+
grpcio = { version = ">=1.73.0", optional = true }
2223
jinja2 = { version = ">=3.0.3", optional = true }
2324
python-dateutil = "^2.8"
2425
typing-extensions = "^4.7.1"
@@ -45,13 +46,15 @@ pydantic = ">=2.0,<3"
4546
protobuf = "^5"
4647
cachelib = "^0.13.0"
4748
tomlkit = ">=0.7.0"
49+
grpcio-testing = "^1.54.2"
4850

4951
[project.scripts]
5052
protoc-gen-python_betterproto = "betterproto.plugin:main"
5153

5254
[project.optional-dependencies]
5355
compiler = ["ruff", "jinja2"]
5456
rust-codec = ["betterproto-rust-codec"]
57+
grpcio = ["grpcio"]
5558

5659
[tool.ruff]
5760
extend-exclude = ["tests/output_*"]
Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,120 @@
1+
from abc import ABC
2+
from typing import (
3+
TYPE_CHECKING,
4+
AsyncIterable,
5+
AsyncIterator,
6+
Iterable,
7+
Mapping,
8+
Optional,
9+
Union,
10+
)
11+
12+
import grpc
13+
14+
if TYPE_CHECKING:
15+
from .._types import (
16+
T,
17+
IProtoMessage,
18+
)
19+
20+
Value = Union[str, bytes]
21+
MetadataLike = Union[Mapping[str, Value], Iterable[tuple[str, Value]]]
22+
MessageSource = Union[Iterable["IProtoMessage"], AsyncIterable["IProtoMessage"]]
23+
24+
25+
class ServiceStub(ABC):
26+
27+
def __init__(
28+
self,
29+
channel: grpc.aio.Channel,
30+
*,
31+
timeout: Optional[float] = None,
32+
metadata: Optional[MetadataLike] = None,
33+
) -> None:
34+
self.channel = channel
35+
self.timeout = timeout
36+
self.metadata = metadata
37+
38+
def _resolve_request_kwargs(
39+
self,
40+
timeout: Optional[float],
41+
metadata: Optional[MetadataLike],
42+
):
43+
return {
44+
"timeout": self.timeout if timeout is None else timeout,
45+
"metadata": self.metadata if metadata is None else metadata,
46+
}
47+
48+
async def _unary_unary(
49+
self,
50+
stub_method: grpc.aio.UnaryUnaryMultiCallable,
51+
request: "IProtoMessage",
52+
*,
53+
timeout: Optional[float] = None,
54+
metadata: Optional[MetadataLike] = None,
55+
) -> "T":
56+
return await stub_method(
57+
request,
58+
**self._resolve_request_kwargs(timeout, metadata),
59+
)
60+
61+
async def _unary_stream(
62+
self,
63+
stub_method: grpc.aio.UnaryStreamMultiCallable,
64+
request: "IProtoMessage",
65+
*,
66+
timeout: Optional[float] = None,
67+
metadata: Optional[MetadataLike] = None,
68+
) -> AsyncIterator["T"]:
69+
call = stub_method(
70+
request,
71+
**self._resolve_request_kwargs(timeout, metadata),
72+
)
73+
async for response in call:
74+
yield response
75+
76+
async def _stream_unary(
77+
self,
78+
stub_method: grpc.aio.StreamUnaryMultiCallable,
79+
request_iterator: MessageSource,
80+
*,
81+
timeout: Optional[float] = None,
82+
metadata: Optional[MetadataLike] = None,
83+
) -> "T":
84+
call = stub_method(
85+
self._wrap_message_iterator(request_iterator),
86+
**self._resolve_request_kwargs(timeout, metadata),
87+
)
88+
return await call
89+
90+
async def _stream_stream(
91+
self,
92+
stub_method: grpc.aio.StreamStreamMultiCallable,
93+
request_iterator: MessageSource,
94+
*,
95+
timeout: Optional[float] = None,
96+
metadata: Optional[MetadataLike] = None,
97+
) -> AsyncIterator["T"]:
98+
call = stub_method(
99+
self._wrap_message_iterator(request_iterator),
100+
**self._resolve_request_kwargs(timeout, metadata),
101+
)
102+
async for response in call:
103+
yield response
104+
105+
@staticmethod
106+
def _wrap_message_iterator(
107+
messages: MessageSource,
108+
) -> AsyncIterator["IProtoMessage"]:
109+
if hasattr(messages, '__aiter__'):
110+
async def async_wrapper():
111+
async for message in messages:
112+
yield message
113+
114+
return async_wrapper()
115+
else:
116+
async def sync_wrapper():
117+
for message in messages:
118+
yield message
119+
120+
return sync_wrapper()
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
from abc import ABC, abstractmethod
2+
from typing import TYPE_CHECKING, Dict
3+
4+
5+
if TYPE_CHECKING:
6+
import grpc
7+
8+
9+
class ServiceBase(ABC):
10+
11+
@property
12+
@abstractmethod
13+
def __rpc_methods__(self) -> Dict[str, "grpc.RpcMethodHandler"]: ...
14+
15+
@property
16+
@abstractmethod
17+
def __proto_path__(self) -> str: ...
18+
19+
20+
def register_servicers(server: "grpc.aio.Server", *servicers: ServiceBase):
21+
from grpc import method_handlers_generic_handler
22+
23+
server.add_generic_rpc_handlers(
24+
tuple(
25+
method_handlers_generic_handler(
26+
servicer.__proto_path__, servicer.__rpc_methods__
27+
)
28+
for servicer in servicers
29+
)
30+
)

src/betterproto/plugin/models.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -270,6 +270,7 @@ class OutputTemplate:
270270
imports_type_checking_only: Set[str] = field(default_factory=set)
271271
pydantic_dataclasses: bool = False
272272
output: bool = True
273+
use_grpcio: bool = False
273274
typing_compiler: TypingCompiler = field(default_factory=DirectImportTypingCompiler)
274275

275276
@property
@@ -697,18 +698,20 @@ class ServiceMethodCompiler(ProtoContentBase):
697698
proto_obj: MethodDescriptorProto
698699
path: List[int] = PLACEHOLDER
699700
comment_indent: int = 8
701+
use_grpcio: bool = False
700702

701703
def __post_init__(self) -> None:
702704
# Add method to service
703705
self.parent.methods.append(self)
704706

705-
self.output_file.imports_type_checking_only.add("import grpclib.server")
706-
self.output_file.imports_type_checking_only.add(
707-
"from betterproto.grpc.grpclib_client import MetadataLike"
708-
)
709-
self.output_file.imports_type_checking_only.add(
710-
"from grpclib.metadata import Deadline"
711-
)
707+
if self.use_grpcio:
708+
imports = ["import grpc.aio", "from betterproto.grpc.grpcio_client import MetadataLike"]
709+
else:
710+
imports = ["import grpclib.server", "from betterproto.grpc.grpclib_client import MetadataLike",
711+
"from grpclib.metadata import Deadline"]
712+
713+
for import_line in imports:
714+
self.output_file.imports_type_checking_only.add(import_line)
712715

713716
super().__post_init__() # check for unset fields
714717

src/betterproto/plugin/parser.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -40,10 +40,11 @@
4040
from .typing_compiler import (
4141
DirectImportTypingCompiler,
4242
NoTyping310TypingCompiler,
43-
TypingCompiler,
4443
TypingImportTypingCompiler,
4544
)
4645

46+
USE_GRPCIO_FLAG = "USE_GRPCIO"
47+
4748

4849
def traverse(
4950
proto_file: FileDescriptorProto,
@@ -80,6 +81,7 @@ def generate_code(request: CodeGeneratorRequest) -> CodeGeneratorResponse:
8081
response.supported_features = CodeGeneratorResponseFeature.FEATURE_PROTO3_OPTIONAL
8182

8283
request_data = PluginRequestCompiler(plugin_request_obj=request)
84+
use_grpcio = USE_GRPCIO_FLAG in plugin_options
8385
# Gather output packages
8486
for proto_file in request.proto_file:
8587
output_package_name = proto_file.package
@@ -90,7 +92,7 @@ def generate_code(request: CodeGeneratorRequest) -> CodeGeneratorResponse:
9092
)
9193
# Add this input file to the output corresponding to this package
9294
request_data.output_packages[output_package_name].input_files.append(proto_file)
93-
95+
request_data.output_packages[output_package_name].use_grpcio = use_grpcio
9496
if (
9597
proto_file.package == "google.protobuf"
9698
and "INCLUDE_GOOGLE" not in plugin_options
@@ -143,7 +145,7 @@ def generate_code(request: CodeGeneratorRequest) -> CodeGeneratorResponse:
143145
for output_package_name, output_package in request_data.output_packages.items():
144146
for proto_input_file in output_package.input_files:
145147
for index, service in enumerate(proto_input_file.service):
146-
read_protobuf_service(proto_input_file, service, index, output_package)
148+
read_protobuf_service(proto_input_file, service, index, output_package, use_grpcio)
147149

148150
# Generate output files
149151
output_paths: Set[pathlib.Path] = set()
@@ -253,6 +255,7 @@ def read_protobuf_service(
253255
service: ServiceDescriptorProto,
254256
index: int,
255257
output_package: OutputTemplate,
258+
use_grpcio: bool = False,
256259
) -> None:
257260
service_data = ServiceCompiler(
258261
source_file=source_file,
@@ -266,4 +269,5 @@ def read_protobuf_service(
266269
parent=service_data,
267270
proto_obj=method,
268271
path=[6, index, 2, j],
272+
use_grpcio=use_grpcio,
269273
)

src/betterproto/plugin/typing_compiler.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,10 @@ def async_iterable(self, type: str) -> str:
4141
def async_iterator(self, type: str) -> str:
4242
raise NotImplementedError()
4343

44+
@abc.abstractmethod
45+
def async_generator(self, type: str) -> str:
46+
raise NotImplementedError()
47+
4448
@abc.abstractmethod
4549
def imports(self) -> Dict[str, Optional[Set[str]]]:
4650
"""
@@ -93,6 +97,10 @@ def async_iterator(self, type: str) -> str:
9397
self._imports["typing"].add("AsyncIterator")
9498
return f"AsyncIterator[{type}]"
9599

100+
def async_generator(self, type: str) -> str:
101+
self._imports["typing"].add("AsyncGenerator")
102+
return f"AsyncGenerator[{type}, None]"
103+
96104
def imports(self) -> Dict[str, Optional[Set[str]]]:
97105
return {k: v if v else None for k, v in self._imports.items()}
98106

@@ -129,6 +137,10 @@ def async_iterator(self, type: str) -> str:
129137
self._imported = True
130138
return f"typing.AsyncIterator[{type}]"
131139

140+
def async_generator(self, type: str) -> str:
141+
self._imported = True
142+
return f"typing.AsyncGenerator[{type}, None]"
143+
132144
def imports(self) -> Dict[str, Optional[Set[str]]]:
133145
if self._imported:
134146
return {"typing": None}
@@ -169,5 +181,9 @@ def async_iterator(self, type: str) -> str:
169181
self._imports["collections.abc"].add("AsyncIterator")
170182
return f'"AsyncIterator[{type}]"'
171183

184+
def async_generator(self, type: str) -> str:
185+
self._imports["collections.abc"].add("AsyncGenerator")
186+
return f'"AsyncGenerator[{type}, None]"'
187+
172188
def imports(self) -> Dict[str, Optional[Set[str]]]:
173189
return {k: v if v else None for k, v in self._imports.items()}

src/betterproto/templates/header.py.j2

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,9 @@ __all__ = (
1313
{%- for service in output_file.services -%}
1414
"{{ service.py_name }}Stub",
1515
"{{ service.py_name }}Base",
16+
{%- if output_file.use_grpcio -%}
17+
"add_{{ service.py_name }}Servicer_to_server",
18+
{%- endif -%}
1619
{%- endfor -%}
1720
)
1821

@@ -29,7 +32,7 @@ from dataclasses import dataclass
2932
{% if output_file.datetime_imports %}
3033
from datetime import {% for i in output_file.datetime_imports|sort %}{{ i }}{% if not loop.last %}, {% endif %}{% endfor %}
3134

32-
{% endif%}
35+
{% endif %}
3336
{% set typing_imports = output_file.typing_compiler.imports() %}
3437
{% if typing_imports %}
3538
{% for line in output_file.typing_compiler.import_lines() %}
@@ -42,8 +45,16 @@ from pydantic import {% for i in output_file.pydantic_imports|sort %}{{ i }}{% i
4245

4346
{% endif %}
4447

48+
49+
{% if output_file.use_grpcio %}
50+
import grpc
51+
from betterproto.grpc.grpcio_client import ServiceStub
52+
from betterproto.grpc.grpcio_server import ServiceBase
53+
{% endif %}
54+
4555
import betterproto
46-
{% if output_file.services %}
56+
{% if not output_file.use_grpcio %}
57+
from betterproto.grpc.grpclib_client import ServiceStub
4758
from betterproto.grpc.grpclib_server import ServiceBase
4859
import grpclib
4960
{% endif %}

0 commit comments

Comments
 (0)