Skip to content

Commit f2969e0

Browse files
AdrienVannsonmokucher
authored andcommitted
Fix README and AsyncGenerator imports
1 parent bb61dad commit f2969e0

File tree

5 files changed

+37
-6
lines changed

5 files changed

+37
-6
lines changed

README.md

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -538,10 +538,14 @@ protoc \
538538

539539
### Using grpcio library instead of grpclib
540540

541-
In order to use the `grpcio` library instead of `grpclib`, you can use the `--custom_opt=grpcio`
541+
In order to use the `grpcio` library instead of `grpclib`, you can use the `--custom_opt=USE_GRPCIO`
542542
option when running the `protoc` command.
543543
This will generate stubs compatible with the `grpcio` library.
544544

545+
Example:
546+
```sh
547+
protoc --custom_opt=USE_GRPCIO -I . --custom_out=generated --plugin=protoc-gen-custom=src/betterproto/plugin/main.py demo.proto
548+
```
545549
### TODO
546550

547551
- [x] Fixed length fields

src/betterproto/plugin/typing_compiler.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,10 @@ def async_iterable(self, type: str) -> str:
4141
def async_iterator(self, type: str) -> str:
4242
raise NotImplementedError()
4343

44+
@abc.abstractmethod
45+
def async_generator(self, type: str) -> str:
46+
raise NotImplementedError()
47+
4448
@abc.abstractmethod
4549
def imports(self) -> Dict[str, Optional[Set[str]]]:
4650
"""
@@ -93,6 +97,10 @@ def async_iterator(self, type: str) -> str:
9397
self._imports["typing"].add("AsyncIterator")
9498
return f"AsyncIterator[{type}]"
9599

100+
def async_generator(self, type: str) -> str:
101+
self._imports["typing"].add("AsyncGenerator")
102+
return f"AsyncGenerator[{type}, None]"
103+
96104
def imports(self) -> Dict[str, Optional[Set[str]]]:
97105
return {k: v if v else None for k, v in self._imports.items()}
98106

@@ -129,6 +137,10 @@ def async_iterator(self, type: str) -> str:
129137
self._imported = True
130138
return f"typing.AsyncIterator[{type}]"
131139

140+
def async_generator(self, type: str) -> str:
141+
self._imported = True
142+
return f"typing.AsyncGenerator[{type}, None]"
143+
132144
def imports(self) -> Dict[str, Optional[Set[str]]]:
133145
if self._imported:
134146
return {"typing": None}
@@ -169,5 +181,9 @@ def async_iterator(self, type: str) -> str:
169181
self._imports["collections.abc"].add("AsyncIterator")
170182
return f'"AsyncIterator[{type}]"'
171183

184+
def async_generator(self, type: str) -> str:
185+
self._imports["collections.abc"].add("AsyncGenerator")
186+
return f'"AsyncGenerator[{type}, None]"'
187+
172188
def imports(self) -> Dict[str, Optional[Set[str]]]:
173189
return {k: v if v else None for k, v in self._imports.items()}

src/betterproto/templates/header.py.j2

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,9 @@ __all__ = (
1313
{%- for service in output_file.services -%}
1414
"{{ service.py_name }}Stub",
1515
"{{ service.py_name }}Base",
16+
{%- if output_file.use_grpcio -%}
17+
"add_{{ service.py_name }}Servicer_to_server",
18+
{%- endif -%}
1619
{%- endfor -%}
1720
)
1821

@@ -29,7 +32,7 @@ from dataclasses import dataclass
2932
{% if output_file.datetime_imports %}
3033
from datetime import {% for i in output_file.datetime_imports|sort %}{{ i }}{% if not loop.last %}, {% endif %}{% endfor %}
3134

32-
{% endif%}
35+
{% endif %}
3336
{% set typing_imports = output_file.typing_compiler.imports() %}
3437
{% if typing_imports %}
3538
{% for line in output_file.typing_compiler.import_lines() %}

src/betterproto/templates/template.py.j2

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,6 @@ class {{ service.py_name }}Stub(betterproto.ServiceStub):
133133
metadata=metadata,
134134
)
135135
{% else %}
136-
# unary_unary pour grpcio - direct call
137136
return await self.channel.unary_unary(
138137
"{{ method.route }}",
139138
request_serializer={{ method.py_input_message_type }}.SerializeToString,
@@ -284,16 +283,20 @@ class {{ service.py_name }}Base(ServiceBase):
284283

285284
{% endif %}
286285

286+
@property
287+
def __proto_path__(self) -> str:
288+
return "{% if output_file.package %}{{ output_file.package }}.{% endif %}{{ service.proto_name }}"
289+
287290
{% for method in service.methods %}
288291
async def {{ method.py_name }}(self
289292
{%- if not method.client_streaming -%}
290293
, request: "{{ method.py_input_message_type }}"
291294
{%- else -%}
292295
{# Client streaming: need a request iterator instead #}
293-
, request_iterator: AsyncIterator["{{ method.py_input_message_type }}"]
296+
, request_iterator: "{{ output_file.typing_compiler.async_iterator(method.py_input_message_type) }}"
294297
{%- endif -%}
295298
, context: grpc.aio.ServicerContext
296-
) -> {% if method.server_streaming %}AsyncGenerator["{{ method.py_output_message_type }}", None]{% else %}"{{ method.py_output_message_type }}"{% endif %}:
299+
) -> {% if method.server_streaming %}"{{ output_file.typing_compiler.async_generator(method.py_output_message_type) }}"{% else %}"{{ method.py_output_message_type }}"{% endif %}:
297300
{% if method.comment %}
298301
{{ method.comment }}
299302

tests/test_typing_compiler.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ def test_direct_import_typing_compiler():
2727
"typing": {"Optional", "List", "Dict", "Union", "Iterable", "AsyncIterable"}
2828
}
2929
assert compiler.async_iterator("str") == "AsyncIterator[str]"
30+
assert compiler.async_generator("str") == "AsyncGenerator[str, None]"
3031
assert compiler.imports() == {
3132
"typing": {
3233
"Optional",
@@ -36,6 +37,7 @@ def test_direct_import_typing_compiler():
3637
"Iterable",
3738
"AsyncIterable",
3839
"AsyncIterator",
40+
"AsyncGenerator",
3941
}
4042
}
4143

@@ -57,6 +59,8 @@ def test_typing_import_typing_compiler():
5759
assert compiler.imports() == {"typing": None}
5860
assert compiler.async_iterator("str") == "typing.AsyncIterator[str]"
5961
assert compiler.imports() == {"typing": None}
62+
assert compiler.async_generator("str") == "typing.AsyncGenerator[str, None]"
63+
assert compiler.imports() == {"typing": None}
6064

6165

6266
def test_no_typing_311_typing_compiler():
@@ -73,6 +77,7 @@ def test_no_typing_311_typing_compiler():
7377
assert compiler.iterable("str") == '"Iterable[str]"'
7478
assert compiler.async_iterable("str") == '"AsyncIterable[str]"'
7579
assert compiler.async_iterator("str") == '"AsyncIterator[str]"'
80+
assert compiler.async_generator("str") == '"AsyncGenerator[str, None]"'
7681
assert compiler.imports() == {
77-
"collections.abc": {"Iterable", "AsyncIterable", "AsyncIterator"}
82+
"collections.abc": {"Iterable", "AsyncIterable", "AsyncIterator", "AsyncGenerator"}
7883
}

0 commit comments

Comments
 (0)