11import logging
22import sys
3- import warnings
43from enum import Enum
54from functools import partial
65from typing import (
1817import grpc
1918from google .protobuf import (
2019 descriptor_pb2 ,
20+ message_factory ,
21+ )
22+ from google .protobuf import (
2123 descriptor_pool as _descriptor_pool ,
24+ )
25+ from google .protobuf import (
2226 symbol_database as _symbol_database ,
23- message_factory ,
24- ) # noqa: E501
27+ )
28+
29+ # noqa: E501
2530from google .protobuf .descriptor import MethodDescriptor , ServiceDescriptor
2631from google .protobuf .descriptor_pb2 import ServiceDescriptorProto
2732from google .protobuf .json_format import MessageToDict , ParseDict
3439
3540if sys .version_info >= (3 , 8 ):
3641 import importlib .metadata
42+ from typing import Protocol
3743
3844 def get_metadata (package_name : str ):
3945 return importlib .metadata .version (package_name )
4046else :
4147 import pkg_resources
48+ from typing_extensions import Protocol
4249
4350 def get_metadata (package_name : str ):
4451 return pkg_resources .get_distribution (package_name ).version
@@ -146,27 +153,67 @@ def __del__(self):
146153 pass
147154
148155
149- def parse_request_data (reqeust_data , input_type ):
150- _data = reqeust_data or {}
151- if isinstance (_data , dict ):
152- request = ParseDict (_data , input_type ())
153- else :
154- request = _data
155- return request
156+ class MessageParsersProtocol (Protocol ):
157+ def parse_request_data (self , request_data , input_type ): ...
156158
159+ def parse_stream_requests (self , stream_requests_data : Iterable , input_type ): ...
157160
158- def parse_stream_requests (stream_requests_data : Iterable , input_type ):
159- for request_data in stream_requests_data :
160- yield parse_request_data (request_data or {}, input_type )
161+ async def parse_response (self , response ): ...
161162
163+ async def parse_stream_responses (self , responses : AsyncIterable ): ...
164+
165+
166+ class MessageParsers (MessageParsersProtocol ):
167+ def parse_request_data (self , request_data , input_type ):
168+ _data = request_data or {}
169+ if isinstance (_data , dict ):
170+ request = ParseDict (_data , input_type ())
171+ else :
172+ request = _data
173+ return request
174+
175+ def parse_stream_requests (self , stream_requests_data : Iterable , input_type ):
176+ for request_data in stream_requests_data :
177+ yield self .parse_request_data (request_data or {}, input_type )
178+
179+ async def parse_response (self , response ):
180+ return MessageToDict (response , preserving_proto_field_name = True )
181+
182+ async def parse_stream_responses (self , responses : AsyncIterable ):
183+ async for resp in responses :
184+ yield await self .parse_response (resp )
185+
186+
187+ class CustomArgumentParsers (MessageParsersProtocol ):
188+ _message_to_dict_kwargs : Dict [str , Any ]
189+ _parse_dict_kwargs : Dict [str , Any ]
190+
191+ def __init__ (
192+ self ,
193+ message_to_dict_kwargs : Dict [str , Any ] = dict (),
194+ parse_dict_kwargs : Dict [str , Any ] = dict (),
195+ ):
196+ self ._message_to_dict_kwargs = message_to_dict_kwargs or {}
197+ self ._parse_dict_kwargs = parse_dict_kwargs or {}
198+
199+ def parse_request_data (self , request_data , input_type ):
200+ _data = request_data or {}
201+ if isinstance (_data , dict ):
202+ request = ParseDict (_data , input_type (), ** self ._parse_dict_kwargs )
203+ else :
204+ request = _data
205+ return request
162206
163- async def parse_response (response ):
164- return MessageToDict (response , preserving_proto_field_name = True )
207+ def parse_stream_requests (self , stream_requests_data : Iterable , input_type ):
208+ for request_data in stream_requests_data :
209+ yield self .parse_request_data (request_data or {}, input_type )
165210
211+ async def parse_response (self , response ):
212+ return MessageToDict (response , ** self ._message_to_dict_kwargs )
166213
167- async def parse_stream_responses (responses : AsyncIterable ):
168- async for resp in responses :
169- yield await parse_response (resp )
214+ async def parse_stream_responses (self , responses : AsyncIterable ):
215+ async for resp in responses :
216+ yield await self . parse_response (resp )
170217
171218
172219class MethodType (Enum ):
@@ -179,25 +226,32 @@ class MethodType(Enum):
179226 def is_unary_request (self ):
180227 return "unary_" in self .value
181228
182- @property
183- def request_parser (self ):
184- return parse_request_data if self .is_unary_request else parse_stream_requests
185-
186229 @property
187230 def is_unary_response (self ):
188231 return "_unary" in self .value
189232
190- @property
191- def response_parser (self ):
192- return parse_response if self .is_unary_response else parse_stream_responses
193-
194233
195234class MethodMetaData (NamedTuple ):
196235 input_type : Any
197236 output_type : Any
198237 method_type : MethodType
199238 handler : Any
200239 descriptor : MethodDescriptor
240+ parsers : MessageParsersProtocol
241+
242+ @property
243+ def request_parser (self ):
244+ if self .method_type .is_unary_request :
245+ return self .parsers .parse_request_data
246+ else :
247+ return self .parsers .parse_stream_requests
248+
249+ @property
250+ def response_parser (self ):
251+ if self .method_type .is_unary_response :
252+ return self .parsers .parse_response
253+ else :
254+ return self .parsers .parse_stream_responses
201255
202256
203257IS_REQUEST_STREAM = TypeVar ("IS_REQUEST_STREAM" )
@@ -220,6 +274,7 @@ def __init__(
220274 ssl = False ,
221275 compression = None ,
222276 skip_check_method_available = False ,
277+ message_parsers : MessageParsersProtocol = MessageParsers (),
223278 ** kwargs ,
224279 ):
225280 super ().__init__ (
@@ -233,6 +288,7 @@ def __init__(
233288 self ._service_names : list = None
234289 self .has_server_registered = False
235290 self ._skip_check_method_available = skip_check_method_available
291+ self ._message_parsers = message_parsers
236292 self ._services_module_name = {}
237293 self ._service_methods_meta : Dict [str , Dict [str , MethodMetaData ]] = {}
238294
@@ -309,6 +365,7 @@ def _register_methods(
309365 output_type = output_type ,
310366 handler = handler ,
311367 descriptor = method_desc ,
368+ parsers = self ._message_parsers ,
312369 )
313370 return metadata
314371
@@ -348,19 +405,17 @@ async def _request(self, service, method, request, raw_output=False, **kwargs):
348405 # does not check request is available
349406 method_meta = self .get_method_meta (service , method )
350407
351- _request = method_meta .method_type .request_parser (
352- request , method_meta .input_type
353- )
408+ _request = method_meta .request_parser (request , method_meta .input_type )
354409 if method_meta .method_type .is_unary_response :
355410 result = await method_meta .handler (_request , ** kwargs )
356411
357412 if raw_output :
358413 return result
359414 else :
360- return await method_meta .method_type . response_parser (result )
415+ return await method_meta .response_parser (result )
361416 else :
362417 result = method_meta .handler (_request , ** kwargs )
363- return method_meta .method_type . response_parser (result )
418+ return method_meta .response_parser (result )
364419
365420 async def request (self , service , method , request = None , raw_output = False , ** kwargs ):
366421 await self .check_method_available (service , method )
@@ -427,6 +482,7 @@ def __init__(
427482 descriptor_pool = None ,
428483 ssl = False ,
429484 compression = None ,
485+ message_parsers : MessageParsersProtocol = MessageParsers (),
430486 ** kwargs ,
431487 ):
432488 super ().__init__ (
@@ -435,6 +491,7 @@ def __init__(
435491 descriptor_pool ,
436492 ssl = ssl ,
437493 compression = compression ,
494+ message_parsers = message_parsers ,
438495 ** kwargs ,
439496 )
440497 self .reflection_stub = reflection_pb2_grpc .ServerReflectionStub (self .channel )
@@ -453,26 +510,6 @@ async def _get_service_names(self):
453510 services = tuple ([s .name for s in resp .list_services_response .service ])
454511 return services
455512
456- async def get_file_descriptor_by_name (self , name ):
457- warnings .warn (
458- "This function is deprecated, and will be removed in the 0.1.17 release. Use get_file_descriptors_by_name() instead." ,
459- DeprecationWarning ,
460- )
461- request = reflection_pb2 .ServerReflectionRequest (file_by_filename = name )
462- result = await self ._reflection_single_request (request )
463- proto = result .file_descriptor_response .file_descriptor_proto [0 ]
464- return descriptor_pb2 .FileDescriptorProto .FromString (proto )
465-
466- async def get_file_descriptor_by_symbol (self , symbol ):
467- warnings .warn (
468- "This function is deprecated, and will be removed in the 0.1.17 release. Use get_file_descriptors_by_symbol() instead." ,
469- DeprecationWarning ,
470- )
471- request = reflection_pb2 .ServerReflectionRequest (file_containing_symbol = symbol )
472- result = await self ._reflection_single_request (request )
473- proto = result .file_descriptor_response .file_descriptor_proto [0 ]
474- return descriptor_pb2 .FileDescriptorProto .FromString (proto )
475-
476513 async def get_file_descriptors_by_name (self , name ):
477514 request = reflection_pb2 .ServerReflectionRequest (file_by_filename = name )
478515 result = await self ._reflection_single_request (request )
0 commit comments