Skip to content

Commit bb61dad

Browse files
committed
Refactor.
1 parent 162d7e6 commit bb61dad

File tree

4 files changed

+49
-64
lines changed

4 files changed

+49
-64
lines changed

src/betterproto/grpc/grpcio_client.py

Lines changed: 35 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -11,14 +11,12 @@
1111

1212
import grpc
1313

14-
1514
if TYPE_CHECKING:
1615
from .._types import (
1716
T,
1817
IProtoMessage,
1918
)
2019

21-
2220
Value = Union[str, bytes]
2321
MetadataLike = Union[Mapping[str, Value], Iterable[tuple[str, Value]]]
2422
MessageSource = Union[Iterable["IProtoMessage"], AsyncIterable["IProtoMessage"]]
@@ -27,56 +25,46 @@
2725
class ServiceStub(ABC):
2826

2927
def __init__(
30-
self,
31-
channel: grpc.aio.Channel,
32-
*,
33-
timeout: Optional[float] = None,
34-
metadata: Optional[MetadataLike] = None,
28+
self,
29+
channel: grpc.aio.Channel,
30+
*,
31+
timeout: Optional[float] = None,
32+
metadata: Optional[MetadataLike] = None,
3533
) -> None:
3634
self.channel = channel
3735
self.timeout = timeout
3836
self.metadata = metadata
3937

4038
def _resolve_request_kwargs(
41-
self,
42-
timeout: Optional[float],
43-
metadata: Optional[MetadataLike],
39+
self,
40+
timeout: Optional[float],
41+
metadata: Optional[MetadataLike],
4442
):
45-
# Avoid creating dict if no overrides needed
46-
if timeout is None and metadata is None:
47-
# Return cached default kwargs if both timeout and metadata are None
48-
if not hasattr(self, '_default_kwargs'):
49-
self._default_kwargs = {
50-
"timeout": self.timeout,
51-
"metadata": self.metadata,
52-
}
53-
return self._default_kwargs
54-
5543
return {
5644
"timeout": self.timeout if timeout is None else timeout,
5745
"metadata": self.metadata if metadata is None else metadata,
5846
}
5947

6048
async def _unary_unary(
61-
self,
62-
stub_method: grpc.aio.UnaryUnaryMultiCallable,
63-
request: "IProtoMessage",
64-
*,
65-
timeout: Optional[float] = None,
66-
metadata: Optional[MetadataLike] = None,
49+
self,
50+
stub_method: grpc.aio.UnaryUnaryMultiCallable,
51+
request: "IProtoMessage",
52+
*,
53+
timeout: Optional[float] = None,
54+
metadata: Optional[MetadataLike] = None,
6755
) -> "T":
6856
return await stub_method(
6957
request,
7058
**self._resolve_request_kwargs(timeout, metadata),
7159
)
7260

7361
async def _unary_stream(
74-
self,
75-
stub_method: grpc.aio.UnaryStreamMultiCallable,
76-
request: "IProtoMessage",
77-
*,
78-
timeout: Optional[float] = None,
79-
metadata: Optional[MetadataLike] = None,
62+
self,
63+
stub_method: grpc.aio.UnaryStreamMultiCallable,
64+
request: "IProtoMessage",
65+
*,
66+
timeout: Optional[float] = None,
67+
metadata: Optional[MetadataLike] = None,
8068
) -> AsyncIterator["T"]:
8169
call = stub_method(
8270
request,
@@ -86,12 +74,12 @@ async def _unary_stream(
8674
yield response
8775

8876
async def _stream_unary(
89-
self,
90-
stub_method: grpc.aio.StreamUnaryMultiCallable,
91-
request_iterator: MessageSource,
92-
*,
93-
timeout: Optional[float] = None,
94-
metadata: Optional[MetadataLike] = None,
77+
self,
78+
stub_method: grpc.aio.StreamUnaryMultiCallable,
79+
request_iterator: MessageSource,
80+
*,
81+
timeout: Optional[float] = None,
82+
metadata: Optional[MetadataLike] = None,
9583
) -> "T":
9684
call = stub_method(
9785
self._wrap_message_iterator(request_iterator),
@@ -100,12 +88,12 @@ async def _stream_unary(
10088
return await call
10189

10290
async def _stream_stream(
103-
self,
104-
stub_method: grpc.aio.StreamStreamMultiCallable,
105-
request_iterator: MessageSource,
106-
*,
107-
timeout: Optional[float] = None,
108-
metadata: Optional[MetadataLike] = None,
91+
self,
92+
stub_method: grpc.aio.StreamStreamMultiCallable,
93+
request_iterator: MessageSource,
94+
*,
95+
timeout: Optional[float] = None,
96+
metadata: Optional[MetadataLike] = None,
10997
) -> AsyncIterator["T"]:
11098
call = stub_method(
11199
self._wrap_message_iterator(request_iterator),
@@ -116,15 +104,17 @@ async def _stream_stream(
116104

117105
@staticmethod
118106
def _wrap_message_iterator(
119-
messages: MessageSource,
107+
messages: MessageSource,
120108
) -> AsyncIterator["IProtoMessage"]:
121109
if hasattr(messages, '__aiter__'):
122110
async def async_wrapper():
123111
async for message in messages:
124112
yield message
113+
125114
return async_wrapper()
126115
else:
127116
async def sync_wrapper():
128117
for message in messages:
129118
yield message
119+
130120
return sync_wrapper()

src/betterproto/plugin/models.py

Lines changed: 7 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -703,19 +703,15 @@ class ServiceMethodCompiler(ProtoContentBase):
703703
def __post_init__(self) -> None:
704704
# Add method to service
705705
self.parent.methods.append(self)
706+
706707
if self.use_grpcio:
707-
self.output_file.imports_type_checking_only.add("import grpc.aio")
708-
self.output_file.imports_type_checking_only.add(
709-
"from betterproto.grpc.grpcio_client import MetadataLike"
710-
)
708+
imports = ["import grpc.aio", "from betterproto.grpc.grpcio_client import MetadataLike"]
711709
else:
712-
self.output_file.imports_type_checking_only.add("import grpclib.server")
713-
self.output_file.imports_type_checking_only.add(
714-
"from betterproto.grpc.grpclib_client import MetadataLike"
715-
)
716-
self.output_file.imports_type_checking_only.add(
717-
"from grpclib.metadata import Deadline"
718-
)
710+
imports = ["import grpclib.server", "from betterproto.grpc.grpclib_client import MetadataLike",
711+
"from grpclib.metadata import Deadline"]
712+
713+
for import_line in imports:
714+
self.output_file.imports_type_checking_only.add(import_line)
719715

720716
super().__post_init__() # check for unset fields
721717

src/betterproto/plugin/parser.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -40,10 +40,11 @@
4040
from .typing_compiler import (
4141
DirectImportTypingCompiler,
4242
NoTyping310TypingCompiler,
43-
TypingCompiler,
4443
TypingImportTypingCompiler,
4544
)
4645

46+
USE_GRPCIO_FLAG = "USE_GRPCIO"
47+
4748

4849
def traverse(
4950
proto_file: FileDescriptorProto,
@@ -80,7 +81,7 @@ def generate_code(request: CodeGeneratorRequest) -> CodeGeneratorResponse:
8081
response.supported_features = CodeGeneratorResponseFeature.FEATURE_PROTO3_OPTIONAL
8182

8283
request_data = PluginRequestCompiler(plugin_request_obj=request)
83-
use_grpcio = "USE_GRPCIO" in plugin_options
84+
use_grpcio = USE_GRPCIO_FLAG in plugin_options
8485
# Gather output packages
8586
for proto_file in request.proto_file:
8687
output_package_name = proto_file.package
@@ -91,8 +92,7 @@ def generate_code(request: CodeGeneratorRequest) -> CodeGeneratorResponse:
9192
)
9293
# Add this input file to the output corresponding to this package
9394
request_data.output_packages[output_package_name].input_files.append(proto_file)
94-
if use_grpcio:
95-
request_data.output_packages[output_package_name].use_grpcio = True
95+
request_data.output_packages[output_package_name].use_grpcio = use_grpcio
9696
if (
9797
proto_file.package == "google.protobuf"
9898
and "INCLUDE_GOOGLE" not in plugin_options
@@ -145,9 +145,7 @@ def generate_code(request: CodeGeneratorRequest) -> CodeGeneratorResponse:
145145
for output_package_name, output_package in request_data.output_packages.items():
146146
for proto_input_file in output_package.input_files:
147147
for index, service in enumerate(proto_input_file.service):
148-
read_protobuf_service(
149-
proto_input_file, service, index, output_package, use_grpcio
150-
)
148+
read_protobuf_service(proto_input_file, service, index, output_package, use_grpcio)
151149

152150
# Generate output files
153151
output_paths: Set[pathlib.Path] = set()

tests/util.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
Union,
1919
)
2020

21+
from betterproto.plugin.parser import USE_GRPCIO_FLAG
2122

2223
os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"
2324

@@ -82,7 +83,7 @@ async def protoc(
8283
if pydantic_dataclasses:
8384
command.append("--custom_opt=pydantic_dataclasses")
8485
if grpcio:
85-
command.append("--custom_opt=USE_GRPCIO")
86+
command.append(f"--custom_opt={USE_GRPCIO_FLAG}")
8687

8788
command.extend(
8889
[

0 commit comments

Comments
 (0)