Skip to content

Commit 67280d6

Browse files
Merge pull request #79 from wesky93/develop
Allow configuration of message_to_dict
2 parents d9de93a + f383892 commit 67280d6

File tree

10 files changed

+490
-207
lines changed

10 files changed

+490
-207
lines changed

CHANGELOG.md

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,17 @@ All notable changes to this project will be documented in this file.
55
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
66
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
77

8+
## [0.1.17](https://github.com/wesky93/grpc_requests/releases/tag/v0.1.17) - 2024-04-22
9+
10+
## Added
11+
12+
- Support for custom message parsing in both async and sync clients
13+
14+
## Removed
15+
16+
- Removed singular FileDescriptor getter methods and Method specific field descriptor
17+
methods as laid out previously.
18+
819
## [0.1.16](https://github.com/wesky93/grpc_requests/releases/tag/v0.1.16) - 2024-03-03
920

1021
## Added

src/examples/README.md

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,26 @@ result = await greeter.HelloEveryone(requests_data)
123123
results = [x async for x in await greeter.SayHelloOneByOne(requests_data)]
124124
```
125125

126+
## Setting a Client's message_to_dict behavior
127+
128+
By utilizing `CustomArgumentParsers`, behavioral arguments can be passed to
129+
message_to_dict at time of Client instantiation. This is available for both
130+
synchronous and asynchronous clients.
131+
132+
```python
133+
client = Client(
134+
"localhost:50051",
135+
message_parsers=CustomArgumentParsers(
136+
message_to_dict_kwargs={
137+
"preserving_proto_field_name": True,
138+
"including_default_value_fields": True,
139+
}
140+
),
141+
)
142+
```
143+
144+
[Review the json_format documentation for what kwargs are available to message_to_dict.](https://googleapis.dev/python/protobuf/latest/google/protobuf/json_format.html)
145+
126146
## Retrieving Information about a Server
127147

128148
All forms of clients expose methods to allow a user to query a server about its

src/grpc_requests/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,4 +7,4 @@
77
)
88
from .client import Client, ReflectionClient, StubClient, get_by_endpoint
99

10-
__version__ = "0.1.16"
10+
__version__ = "0.1.17"

src/grpc_requests/aio.py

Lines changed: 88 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import logging
22
import sys
3-
import warnings
43
from enum import Enum
54
from functools import partial
65
from typing import (
@@ -18,10 +17,16 @@
1817
import grpc
1918
from google.protobuf import (
2019
descriptor_pb2,
20+
message_factory,
21+
)
22+
from google.protobuf import (
2123
descriptor_pool as _descriptor_pool,
24+
)
25+
from google.protobuf import (
2226
symbol_database as _symbol_database,
23-
message_factory,
24-
) # noqa: E501
27+
)
28+
29+
# noqa: E501
2530
from google.protobuf.descriptor import MethodDescriptor, ServiceDescriptor
2631
from google.protobuf.descriptor_pb2 import ServiceDescriptorProto
2732
from google.protobuf.json_format import MessageToDict, ParseDict
@@ -34,11 +39,13 @@
3439

3540
if sys.version_info >= (3, 8):
3641
import importlib.metadata
42+
from typing import Protocol
3743

3844
def get_metadata(package_name: str):
3945
return importlib.metadata.version(package_name)
4046
else:
4147
import pkg_resources
48+
from typing_extensions import Protocol
4249

4350
def get_metadata(package_name: str):
4451
return pkg_resources.get_distribution(package_name).version
@@ -146,27 +153,67 @@ def __del__(self):
146153
pass
147154

148155

149-
def parse_request_data(reqeust_data, input_type):
150-
_data = reqeust_data or {}
151-
if isinstance(_data, dict):
152-
request = ParseDict(_data, input_type())
153-
else:
154-
request = _data
155-
return request
156+
class MessageParsersProtocol(Protocol):
157+
def parse_request_data(self, request_data, input_type): ...
156158

159+
def parse_stream_requests(self, stream_requests_data: Iterable, input_type): ...
157160

158-
def parse_stream_requests(stream_requests_data: Iterable, input_type):
159-
for request_data in stream_requests_data:
160-
yield parse_request_data(request_data or {}, input_type)
161+
async def parse_response(self, response): ...
161162

163+
async def parse_stream_responses(self, responses: AsyncIterable): ...
164+
165+
166+
class MessageParsers(MessageParsersProtocol):
167+
def parse_request_data(self, request_data, input_type):
168+
_data = request_data or {}
169+
if isinstance(_data, dict):
170+
request = ParseDict(_data, input_type())
171+
else:
172+
request = _data
173+
return request
174+
175+
def parse_stream_requests(self, stream_requests_data: Iterable, input_type):
176+
for request_data in stream_requests_data:
177+
yield self.parse_request_data(request_data or {}, input_type)
178+
179+
async def parse_response(self, response):
180+
return MessageToDict(response, preserving_proto_field_name=True)
181+
182+
async def parse_stream_responses(self, responses: AsyncIterable):
183+
async for resp in responses:
184+
yield await self.parse_response(resp)
185+
186+
187+
class CustomArgumentParsers(MessageParsersProtocol):
188+
_message_to_dict_kwargs: Dict[str, Any]
189+
_parse_dict_kwargs: Dict[str, Any]
190+
191+
def __init__(
192+
self,
193+
message_to_dict_kwargs: Dict[str, Any] = dict(),
194+
parse_dict_kwargs: Dict[str, Any] = dict(),
195+
):
196+
self._message_to_dict_kwargs = message_to_dict_kwargs or {}
197+
self._parse_dict_kwargs = parse_dict_kwargs or {}
198+
199+
def parse_request_data(self, request_data, input_type):
200+
_data = request_data or {}
201+
if isinstance(_data, dict):
202+
request = ParseDict(_data, input_type(), **self._parse_dict_kwargs)
203+
else:
204+
request = _data
205+
return request
162206

163-
async def parse_response(response):
164-
return MessageToDict(response, preserving_proto_field_name=True)
207+
def parse_stream_requests(self, stream_requests_data: Iterable, input_type):
208+
for request_data in stream_requests_data:
209+
yield self.parse_request_data(request_data or {}, input_type)
165210

211+
async def parse_response(self, response):
212+
return MessageToDict(response, **self._message_to_dict_kwargs)
166213

167-
async def parse_stream_responses(responses: AsyncIterable):
168-
async for resp in responses:
169-
yield await parse_response(resp)
214+
async def parse_stream_responses(self, responses: AsyncIterable):
215+
async for resp in responses:
216+
yield await self.parse_response(resp)
170217

171218

172219
class MethodType(Enum):
@@ -179,25 +226,32 @@ class MethodType(Enum):
179226
def is_unary_request(self):
180227
return "unary_" in self.value
181228

182-
@property
183-
def request_parser(self):
184-
return parse_request_data if self.is_unary_request else parse_stream_requests
185-
186229
@property
187230
def is_unary_response(self):
188231
return "_unary" in self.value
189232

190-
@property
191-
def response_parser(self):
192-
return parse_response if self.is_unary_response else parse_stream_responses
193-
194233

195234
class MethodMetaData(NamedTuple):
196235
input_type: Any
197236
output_type: Any
198237
method_type: MethodType
199238
handler: Any
200239
descriptor: MethodDescriptor
240+
parsers: MessageParsersProtocol
241+
242+
@property
243+
def request_parser(self):
244+
if self.method_type.is_unary_request:
245+
return self.parsers.parse_request_data
246+
else:
247+
return self.parsers.parse_stream_requests
248+
249+
@property
250+
def response_parser(self):
251+
if self.method_type.is_unary_response:
252+
return self.parsers.parse_response
253+
else:
254+
return self.parsers.parse_stream_responses
201255

202256

203257
IS_REQUEST_STREAM = TypeVar("IS_REQUEST_STREAM")
@@ -220,6 +274,7 @@ def __init__(
220274
ssl=False,
221275
compression=None,
222276
skip_check_method_available=False,
277+
message_parsers: MessageParsersProtocol = MessageParsers(),
223278
**kwargs,
224279
):
225280
super().__init__(
@@ -233,6 +288,7 @@ def __init__(
233288
self._service_names: list = None
234289
self.has_server_registered = False
235290
self._skip_check_method_available = skip_check_method_available
291+
self._message_parsers = message_parsers
236292
self._services_module_name = {}
237293
self._service_methods_meta: Dict[str, Dict[str, MethodMetaData]] = {}
238294

@@ -309,6 +365,7 @@ def _register_methods(
309365
output_type=output_type,
310366
handler=handler,
311367
descriptor=method_desc,
368+
parsers=self._message_parsers,
312369
)
313370
return metadata
314371

@@ -348,19 +405,17 @@ async def _request(self, service, method, request, raw_output=False, **kwargs):
348405
# does not check request is available
349406
method_meta = self.get_method_meta(service, method)
350407

351-
_request = method_meta.method_type.request_parser(
352-
request, method_meta.input_type
353-
)
408+
_request = method_meta.request_parser(request, method_meta.input_type)
354409
if method_meta.method_type.is_unary_response:
355410
result = await method_meta.handler(_request, **kwargs)
356411

357412
if raw_output:
358413
return result
359414
else:
360-
return await method_meta.method_type.response_parser(result)
415+
return await method_meta.response_parser(result)
361416
else:
362417
result = method_meta.handler(_request, **kwargs)
363-
return method_meta.method_type.response_parser(result)
418+
return method_meta.response_parser(result)
364419

365420
async def request(self, service, method, request=None, raw_output=False, **kwargs):
366421
await self.check_method_available(service, method)
@@ -427,6 +482,7 @@ def __init__(
427482
descriptor_pool=None,
428483
ssl=False,
429484
compression=None,
485+
message_parsers: MessageParsersProtocol = MessageParsers(),
430486
**kwargs,
431487
):
432488
super().__init__(
@@ -435,6 +491,7 @@ def __init__(
435491
descriptor_pool,
436492
ssl=ssl,
437493
compression=compression,
494+
message_parsers=message_parsers,
438495
**kwargs,
439496
)
440497
self.reflection_stub = reflection_pb2_grpc.ServerReflectionStub(self.channel)
@@ -453,26 +510,6 @@ async def _get_service_names(self):
453510
services = tuple([s.name for s in resp.list_services_response.service])
454511
return services
455512

456-
async def get_file_descriptor_by_name(self, name):
457-
warnings.warn(
458-
"This function is deprecated, and will be removed in the 0.1.17 release. Use get_file_descriptors_by_name() instead.",
459-
DeprecationWarning,
460-
)
461-
request = reflection_pb2.ServerReflectionRequest(file_by_filename=name)
462-
result = await self._reflection_single_request(request)
463-
proto = result.file_descriptor_response.file_descriptor_proto[0]
464-
return descriptor_pb2.FileDescriptorProto.FromString(proto)
465-
466-
async def get_file_descriptor_by_symbol(self, symbol):
467-
warnings.warn(
468-
"This function is deprecated, and will be removed in the 0.1.17 release. Use get_file_descriptors_by_symbol() instead.",
469-
DeprecationWarning,
470-
)
471-
request = reflection_pb2.ServerReflectionRequest(file_containing_symbol=symbol)
472-
result = await self._reflection_single_request(request)
473-
proto = result.file_descriptor_response.file_descriptor_proto[0]
474-
return descriptor_pb2.FileDescriptorProto.FromString(proto)
475-
476513
async def get_file_descriptors_by_name(self, name):
477514
request = reflection_pb2.ServerReflectionRequest(file_by_filename=name)
478515
result = await self._reflection_single_request(request)

0 commit comments

Comments
 (0)