diff --git a/conformance/test/client.py b/conformance/test/client.py index 7edadcd..f599886 100644 --- a/conformance/test/client.py +++ b/conformance/test/client.py @@ -45,6 +45,7 @@ from connectrpc.client import ResponseMetadata from connectrpc.code import Code +from connectrpc.codec import proto_json_codec from connectrpc.compression.brotli import BrotliCompression from connectrpc.compression.gzip import GzipCompression from connectrpc.compression.zstd import ZstdCompression @@ -173,7 +174,9 @@ async def client_sync( ZstdCompression(), ], send_compression=_convert_compression(test_request.compression), - proto_json=test_request.codec == Codec.CODEC_JSON, + codec=proto_json_codec() + if test_request.codec == Codec.CODEC_JSON + else None, protocol=protocol, read_max_bytes=read_max_bytes, ) as client, @@ -220,7 +223,9 @@ async def client_async( ZstdCompression(), ], send_compression=_convert_compression(test_request.compression), - proto_json=test_request.codec == Codec.CODEC_JSON, + codec=proto_json_codec() + if test_request.codec == Codec.CODEC_JSON + else None, protocol=protocol, read_max_bytes=read_max_bytes, ) as client, diff --git a/conformance/test/gen/connectrpc/conformance/v1/service_connect.py b/conformance/test/gen/connectrpc/conformance/v1/service_connect.py index b66bfcf..b95f3f7 100644 --- a/conformance/test/gen/connectrpc/conformance/v1/service_connect.py +++ b/conformance/test/gen/connectrpc/conformance/v1/service_connect.py @@ -26,6 +26,7 @@ Mapping, ) + from connectrpc.codec import Codec from connectrpc.compression import Compression from connectrpc.interceptor import Interceptor, InterceptorSync from connectrpc.request import Headers, RequestContext @@ -91,6 +92,7 @@ def __init__( interceptors: Iterable[Interceptor] = (), read_max_bytes: int | None = None, compressions: Iterable[Compression] | None = None, + codecs: Iterable[Codec] | None = None, ) -> None: super().__init__( service=service, @@ -159,6 +161,7 @@ def __init__( interceptors=interceptors, read_max_bytes=read_max_bytes, compressions=compressions, + codecs=codecs, ) @property @@ -358,6 +361,7 @@ def __init__( interceptors: Iterable[InterceptorSync] = (), read_max_bytes: int | None = None, compressions: Iterable[Compression] | None = None, + codecs: Iterable[Codec] | None = None, ) -> None: super().__init__( endpoints={ @@ -425,6 +429,7 @@ def __init__( interceptors=interceptors, read_max_bytes=read_max_bytes, compressions=compressions, + codecs=codecs, ) @property diff --git a/docs/api.md b/docs/api.md index eeb2327..643de07 100644 --- a/docs/api.md +++ b/docs/api.md @@ -11,6 +11,9 @@ ::: connectrpc.interceptor +::: connectrpc.codec +::: connectrpc.protocol + ::: connectrpc.compression ::: connectrpc.compression.brotli ::: connectrpc.compression.gzip diff --git a/example/example/eliza_connect.py b/example/example/eliza_connect.py index 6ab062e..54d39fc 100644 --- a/example/example/eliza_connect.py +++ b/example/example/eliza_connect.py @@ -26,6 +26,7 @@ Mapping, ) + from connectrpc.codec import Codec from connectrpc.compression import Compression from connectrpc.interceptor import Interceptor, InterceptorSync from connectrpc.request import Headers, RequestContext @@ -58,6 +59,7 @@ def __init__( interceptors: Iterable[Interceptor] = (), read_max_bytes: int | None = None, compressions: Iterable[Compression] | None = None, + codecs: Iterable[Codec] | None = None, ) -> None: super().__init__( service=service, @@ -96,6 +98,7 @@ def __init__( interceptors=interceptors, read_max_bytes=read_max_bytes, compressions=compressions, + codecs=codecs, ) @property @@ -194,6 +197,7 @@ def __init__( interceptors: Iterable[InterceptorSync] = (), read_max_bytes: int | None = None, compressions: Iterable[Compression] | None = None, + codecs: Iterable[Codec] | None = None, ) -> None: super().__init__( endpoints={ @@ -231,6 +235,7 @@ def __init__( interceptors=interceptors, read_max_bytes=read_max_bytes, compressions=compressions, + codecs=codecs, ) @property diff --git a/protoc-gen-connect-python/generator/template.go b/protoc-gen-connect-python/generator/template.go index 3f74e0b..8a8d32e 100644 --- a/protoc-gen-connect-python/generator/template.go +++ b/protoc-gen-connect-python/generator/template.go @@ -49,6 +49,7 @@ from typing import Protocol from connectrpc.client import ConnectClient, ConnectClientSync from connectrpc.code import Code +from connectrpc.codec import Codec from connectrpc.compression import Compression from connectrpc.errors import ConnectError from connectrpc.interceptor import Interceptor, InterceptorSync @@ -69,7 +70,7 @@ class {{.Name}}(Protocol):{{- range .Methods }} {{ end }} class {{.Name}}ASGIApplication(ConnectASGIApplication[{{.Name}}]): - def __init__(self, service: {{.Name}} | AsyncGenerator[{{.Name}}], *, interceptors: Iterable[Interceptor]=(), read_max_bytes: int | None = None, compressions: Iterable[Compression] | None = None) -> None: + def __init__(self, service: {{.Name}} | AsyncGenerator[{{.Name}}], *, interceptors: Iterable[Interceptor]=(), read_max_bytes: int | None = None, compressions: Iterable[Compression] | None = None, codecs: Iterable[Codec] | None = None) -> None: super().__init__( service=service, endpoints=lambda svc: { {{- range .Methods }} @@ -87,6 +88,7 @@ class {{.Name}}ASGIApplication(ConnectASGIApplication[{{.Name}}]): interceptors=interceptors, read_max_bytes=read_max_bytes, compressions=compressions, + codecs=codecs, ) @property @@ -130,7 +132,7 @@ class {{.Name}}Sync(Protocol):{{- range .Methods }} class {{.Name}}WSGIApplication(ConnectWSGIApplication): - def __init__(self, service: {{.Name}}Sync, interceptors: Iterable[InterceptorSync]=(), read_max_bytes: int | None = None, compressions: Iterable[Compression] | None = None) -> None: + def __init__(self, service: {{.Name}}Sync, interceptors: Iterable[InterceptorSync]=(), read_max_bytes: int | None = None, compressions: Iterable[Compression] | None = None, codecs: Iterable[Codec] | None = None) -> None: super().__init__( endpoints={ {{- range .Methods }} "/{{.ServiceName}}/{{.Name}}": EndpointSync.{{.EndpointType}}( @@ -147,6 +149,7 @@ class {{.Name}}WSGIApplication(ConnectWSGIApplication): interceptors=interceptors, read_max_bytes=read_max_bytes, compressions=compressions, + codecs=codecs, ) @property diff --git a/src/connectrpc/_client_async.py b/src/connectrpc/_client_async.py index 17caa3b..32c44a0 100644 --- a/src/connectrpc/_client_async.py +++ b/src/connectrpc/_client_async.py @@ -12,7 +12,7 @@ from . import _client_shared from ._asyncio_timeout import timeout as asyncio_timeout -from ._codec import Codec, get_proto_binary_codec, get_proto_json_codec +from ._codec import proto_binary_codec from ._compression import IdentityCompression, _gzip, resolve_compressions from ._interceptor_async import ( BidiStreamInterceptor, @@ -43,6 +43,7 @@ from types import TracebackType from ._envelope import EnvelopeReader + from .codec import Codec from .compression import Compression from .method import MethodInfo from .request import Headers, RequestContext @@ -92,7 +93,7 @@ def __init__( self, address: str, *, - proto_json: bool = False, + codec: Codec | None = None, protocol: ProtocolType = ProtocolType.CONNECT, accept_compression: Iterable[Compression] | None = None, send_compression: Compression | None = _gzip, @@ -105,7 +106,8 @@ def __init__( Args: address: The address of the server to connect to, including scheme. - proto_json: Whether to use JSON for the protocol. + codec: The [Codec][] to use for requests. If unset, defaults to binary protobuf. + For JSON encoding, use [proto_json_codec][connectrpc.codec.proto_json_codec]. protocol: The [ProtocolType][] to use for requests. accept_compression: Compression algorithms to accept from the server. If unset, defaults to gzip. If set to empty, disables response compression. @@ -117,7 +119,7 @@ def __init__( http_client: A pyqwest Client to use for requests. """ self._address = address - self._codec = get_proto_json_codec() if proto_json else get_proto_binary_codec() + self._codec = codec or proto_binary_codec() self._response_compressions = resolve_compressions(accept_compression) self._accept_compression_header = ",".join(self._response_compressions.keys()) self._send_compression = send_compression or IdentityCompression() diff --git a/src/connectrpc/_client_sync.py b/src/connectrpc/_client_sync.py index 7fa373e..fd4bbd7 100644 --- a/src/connectrpc/_client_sync.py +++ b/src/connectrpc/_client_sync.py @@ -10,7 +10,7 @@ from connectrpc._protocol_grpc import GRPCClientProtocol, GRPCWebClientProtocol from . import _client_shared -from ._codec import Codec, get_proto_binary_codec, get_proto_json_codec +from ._codec import proto_binary_codec from ._compression import IdentityCompression, _gzip, resolve_compressions from ._interceptor_sync import ( BidiStreamInterceptorSync, @@ -33,6 +33,7 @@ from types import TracebackType from ._envelope import EnvelopeReader + from .codec import Codec from .compression import Compression from .method import MethodInfo from .request import Headers, RequestContext @@ -82,7 +83,7 @@ def __init__( self, address: str, *, - proto_json: bool = False, + codec: Codec | None = None, protocol: ProtocolType = ProtocolType.CONNECT, accept_compression: Iterable[Compression] | None = None, send_compression: Compression | None = _gzip, @@ -95,7 +96,8 @@ def __init__( Args: address: The address of the server to connect to, including scheme. - proto_json: Whether to use JSON for the protocol. + codec: The [Codec][] to use for requests. If unset, defaults to binary protobuf. + For JSON encoding, use [proto_json_codec][connectrpc.codec.proto_json_codec]. protocol: The [ProtocolType][] to use for requests. accept_compression: Compression algorithms to accept from the server. If unset, defaults to gzip. If set to empty, disables response compression. @@ -107,7 +109,7 @@ def __init__( http_client: A pyqwest SyncClient to use for requests. """ self._address = address - self._codec = get_proto_json_codec() if proto_json else get_proto_binary_codec() + self._codec = codec or proto_binary_codec() self._timeout_ms = timeout_ms self._read_max_bytes = read_max_bytes self._response_compressions = resolve_compressions(accept_compression) diff --git a/src/connectrpc/_codec.py b/src/connectrpc/_codec.py index 10167cb..06190f4 100644 --- a/src/connectrpc/_codec.py +++ b/src/connectrpc/_codec.py @@ -20,7 +20,10 @@ class Codec(Protocol[T_contra, U]): def name(self) -> str: - """Returns the name of the codec.""" + """Returns the name of the codec. + + This corresponds to the content-type used in requests. + """ ... def encode(self, message: T_contra) -> bytes: @@ -49,8 +52,11 @@ def decode(self, data: bytes | bytearray, message: V) -> V: class ProtoJSONCodec(Codec[Message, V]): """Codec for Protocol bytes | bytearrays JSON format.""" + def __init__(self, name: str = "json") -> None: + self._name = name + def name(self) -> str: - return "json" + return self._name def encode(self, message: Message) -> bytes: return MessageToJson(message).encode() @@ -60,27 +66,24 @@ def decode(self, data: bytes | bytearray, message: V) -> V: return message -# TODO: Codecs can generally be customized per handler instead of as a global -# registry, though the usage isn't common. _proto_binary_codec = ProtoBinaryCodec() _proto_json_codec = ProtoJSONCodec() -_codecs = { - CODEC_NAME_PROTO: _proto_binary_codec, - CODEC_NAME_JSON: _proto_json_codec, - CODEC_NAME_JSON_CHARSET_UTF8: _proto_json_codec, -} +_default_codecs = [ + _proto_binary_codec, + _proto_json_codec, + ProtoJSONCodec(name=CODEC_NAME_JSON_CHARSET_UTF8), +] -def get_proto_binary_codec() -> Codec: - """Returns the Protocol bytes | bytearrays binary codec.""" - return _proto_binary_codec +def get_default_codecs() -> list[Codec]: + return _default_codecs -def get_proto_json_codec() -> Codec: - """Returns the Protocol bytes | bytearrays JSON codec.""" - return _proto_json_codec +def proto_binary_codec() -> Codec: + """Returns the Protocol Buffers binary codec.""" + return _proto_binary_codec -def get_codec(name: str) -> Codec | None: - """Returns the codec with the given name.""" - return _codecs.get(name) +def proto_json_codec() -> Codec: + """Returns the Protocol Buffers JSON codec.""" + return _proto_json_codec diff --git a/src/connectrpc/_server_async.py b/src/connectrpc/_server_async.py index e20d0d8..e9a8b62 100644 --- a/src/connectrpc/_server_async.py +++ b/src/connectrpc/_server_async.py @@ -11,7 +11,7 @@ from typing import TYPE_CHECKING, Generic, TypeVar, cast from urllib.parse import parse_qs -from ._codec import Codec, get_codec +from ._codec import Codec, get_default_codecs from ._compression import negotiate_compression, resolve_compressions from ._envelope import EnvelopeReader from ._interceptor_async import ( @@ -91,6 +91,7 @@ def __init__( interceptors: Iterable[Interceptor] = (), read_max_bytes: int | None = None, compressions: Iterable[Compression] | None = None, + codecs: Iterable[Codec] | None = None, ) -> None: """Initialize the ASGI application. @@ -103,6 +104,8 @@ def __init__( read_max_bytes: Maximum size of request messages. compressions: Supported compression algorithms. If unset, defaults to gzip. If set to empty, disables compression. + codecs: The codecs supported by the server. If unset, defaults to Protocol Buffers + binary and JSON codecs. """ super().__init__() self._service = service @@ -111,6 +114,8 @@ def __init__( self._resolved_endpoints = None self._read_max_bytes = read_max_bytes self._compressions = resolve_compressions(compressions) + codecs = codecs if codecs is not None else get_default_codecs() + self._codecs = {codec.name(): codec for codec in codecs} async def __call__( self, scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable @@ -208,7 +213,7 @@ async def __call__( codec_name = protocol.codec_name_from_content_type( headers.get("content-type", ""), stream=not is_unary ) - codec = get_codec(codec_name.lower()) + codec = self._codecs.get(codec_name.lower()) if not codec: raise HTTPException( HTTPStatus.UNSUPPORTED_MEDIA_TYPE, diff --git a/src/connectrpc/_server_sync.py b/src/connectrpc/_server_sync.py index 3182b1e..83b3896 100644 --- a/src/connectrpc/_server_sync.py +++ b/src/connectrpc/_server_sync.py @@ -9,7 +9,7 @@ from urllib.parse import parse_qs from . import _server_shared -from ._codec import Codec, get_codec +from ._codec import Codec, get_default_codecs from ._compression import negotiate_compression, resolve_compressions from ._envelope import EnvelopeReader, EnvelopeWriter from ._interceptor_sync import ( @@ -168,6 +168,7 @@ def __init__( interceptors: Iterable[InterceptorSync] = (), read_max_bytes: int | None = None, compressions: Iterable[Compression] | None = None, + codecs: Iterable[Codec] | None = None, ) -> None: """Initialize the WSGI application. @@ -178,6 +179,8 @@ def __init__( read_max_bytes: Maximum size of request messages. compressions: Supported compression algorithms. If unset, defaults to gzip. If set to empty, disables compression. + codecs: The codecs supported by the server. If unset, defaults to Protocol Buffers + binary and JSON codecs. """ super().__init__() if interceptors: @@ -194,6 +197,8 @@ def __init__( self._endpoints = endpoints self._read_max_bytes = read_max_bytes self._compressions = resolve_compressions(compressions) + codecs = codecs if codecs is not None else get_default_codecs() + self._codecs = {codec.name(): codec for codec in codecs} def __call__( self, environ: WSGIEnvironment, start_response: StartResponse @@ -299,7 +304,7 @@ def _handle_post_request( codec_name = codec_name_from_content_type( request_headers.get("content-type", ""), stream=False ) - codec = get_codec(codec_name) + codec = self._codecs.get(codec_name) if not codec: raise HTTPException( HTTPStatus.UNSUPPORTED_MEDIA_TYPE, @@ -393,7 +398,7 @@ def _handle_get_request( message = compression.decompress(message) codec_name = params.get("encoding", ("",))[0] - codec = get_codec(codec_name) + codec = self._codecs.get(codec_name) if not codec: raise ConnectError( Code.UNIMPLEMENTED, f"invalid message encoding: '{codec_name}'" @@ -430,7 +435,7 @@ def _handle_stream( codec_name = protocol.codec_name_from_content_type( headers.get("content-type", ""), stream=True ) - codec = get_codec(codec_name) + codec = self._codecs.get(codec_name) if not codec: raise HTTPException( HTTPStatus.UNSUPPORTED_MEDIA_TYPE, diff --git a/src/connectrpc/codec.py b/src/connectrpc/codec.py new file mode 100644 index 0000000..444057c --- /dev/null +++ b/src/connectrpc/codec.py @@ -0,0 +1,5 @@ +from __future__ import annotations + +__all__ = ["Codec", "proto_binary_codec", "proto_json_codec"] + +from ._codec import Codec, proto_binary_codec, proto_json_codec diff --git a/test/haberdasher_connect.py b/test/haberdasher_connect.py index 541420c..1d4f79a 100644 --- a/test/haberdasher_connect.py +++ b/test/haberdasher_connect.py @@ -28,6 +28,7 @@ Mapping, ) + from connectrpc.codec import Codec from connectrpc.compression import Compression from connectrpc.interceptor import Interceptor, InterceptorSync from connectrpc.request import Headers, RequestContext @@ -73,6 +74,7 @@ def __init__( interceptors: Iterable[Interceptor] = (), read_max_bytes: int | None = None, compressions: Iterable[Compression] | None = None, + codecs: Iterable[Codec] | None = None, ) -> None: super().__init__( service=service, @@ -141,6 +143,7 @@ def __init__( interceptors=interceptors, read_max_bytes=read_max_bytes, compressions=compressions, + codecs=codecs, ) @property @@ -312,6 +315,7 @@ def __init__( interceptors: Iterable[InterceptorSync] = (), read_max_bytes: int | None = None, compressions: Iterable[Compression] | None = None, + codecs: Iterable[Codec] | None = None, ) -> None: super().__init__( endpoints={ @@ -379,6 +383,7 @@ def __init__( interceptors=interceptors, read_max_bytes=read_max_bytes, compressions=compressions, + codecs=codecs, ) @property diff --git a/test/test_codec.py b/test/test_codec.py new file mode 100644 index 0000000..9645a86 --- /dev/null +++ b/test/test_codec.py @@ -0,0 +1,143 @@ +from __future__ import annotations + +import pytest +from google.protobuf.message import Message +from pyqwest import ( + Client, + Request, + Response, + SyncClient, + SyncRequest, + SyncResponse, + SyncTransport, + Transport, +) +from pyqwest.testing import ASGITransport, WSGITransport + +from connectrpc.codec import Codec + +from .haberdasher_connect import ( + Haberdasher, + HaberdasherASGIApplication, + HaberdasherClient, + HaberdasherClientSync, + HaberdasherSync, + HaberdasherWSGIApplication, +) +from .haberdasher_pb2 import Hat, Size + + +class CustomCodec(Codec[Message, Message]): + def name(self) -> str: + return "proto" + + def encode(self, message: Message) -> bytes: + match message: + case Size(inches=inches): + return f"{inches}".encode() + case Hat(size=size, color=color): + return f"{size}:{color}".encode() + case _: + raise ValueError(f"unexpected message type: {type(message)}") + + def decode(self, data: bytes | bytearray, message: Message) -> Message: + s = data.decode() + match message: + case Size(): + message.inches = int(s) + case Hat(): + size, color = s.split(":") + message.size = int(size) + message.color = color + return message + + +class SimpleHaberdasher(Haberdasher): + async def make_hat(self, request: Size, ctx): + return Hat(size=request.inches, color="blue") + + +class SimpleHabersahserSync(HaberdasherSync): + def make_hat(self, request: Size, ctx): + return Hat(size=request.inches, color="blue") + + +@pytest.mark.asyncio +async def test_custom_codec() -> None: + logged_content: bytes = b"" + + class LoggingTransport(Transport): + def __init__(self, transport: Transport) -> None: + self._transport = transport + self.last_request_data: bytes | None = None + + async def execute(self, request: Request) -> Response: + chunks = [] + async for chunk in request.content: + chunks.append(chunk) + nonlocal logged_content + logged_content = b"".join(chunks) + return await self._transport.execute( + Request( + method=request.method, + url=request.url, + headers=request.headers, + content=logged_content, + ) + ) + + transport = LoggingTransport( + ASGITransport( + HaberdasherASGIApplication(SimpleHaberdasher(), codecs=[CustomCodec()]) + ) + ) + client = HaberdasherClient( + "http://localhost", + http_client=Client(transport), + codec=CustomCodec(), + send_compression=None, + ) + + res = await client.make_hat(Size(inches=10)) + assert res.size == 10 + assert res.color == "blue" + # Should be enough to just log/assert the client side + assert logged_content == b"10" + + +def test_custom_codec_sync() -> None: + logged_content: bytes = b"" + + class LoggingSyncTransport(SyncTransport): + def __init__(self, transport: SyncTransport) -> None: + self._transport = transport + + def execute_sync(self, request: SyncRequest) -> SyncResponse: + nonlocal logged_content + logged_content = b"".join(request.content) + return self._transport.execute_sync( + SyncRequest( + method=request.method, + url=request.url, + headers=request.headers, + content=logged_content, + ) + ) + + transport = LoggingSyncTransport( + WSGITransport( + HaberdasherWSGIApplication(SimpleHabersahserSync(), codecs=[CustomCodec()]) + ) + ) + client = HaberdasherClientSync( + "http://localhost", + http_client=SyncClient(transport=transport), + codec=CustomCodec(), + send_compression=None, + ) + + res = client.make_hat(Size(inches=10)) + assert res.size == 10 + assert res.color == "blue" + # Should be enough to just log/assert the client side + assert logged_content == b"10" diff --git a/test/test_roundtrip.py b/test/test_roundtrip.py index 406d07c..7294346 100644 --- a/test/test_roundtrip.py +++ b/test/test_roundtrip.py @@ -9,6 +9,7 @@ from pyqwest.testing import ASGITransport, WSGITransport from connectrpc.code import Code +from connectrpc.codec import proto_json_codec from connectrpc.errors import ConnectError from ._util import resolve_compression @@ -42,7 +43,7 @@ def make_hat(self, request, ctx): with HaberdasherClientSync( "http://localhost", http_client=SyncClient(WSGITransport(app=app)), - proto_json=proto_json, + codec=proto_json_codec() if proto_json else None, send_compression=compression, accept_compression=[compression], ) as client: @@ -65,7 +66,7 @@ async def make_hat(self, request, ctx): async with HaberdasherClient( "http://localhost", http_client=Client(transport), - proto_json=proto_json, + codec=proto_json_codec() if proto_json else None, send_compression=compression, accept_compression=[compression], ) as client: @@ -135,7 +136,7 @@ async def make_similar_hats(self, request, ctx): async with HaberdasherClient( "http://localhost", http_client=Client(transport=transport), - proto_json=proto_json, + codec=proto_json_codec() if proto_json else None, send_compression=compression, accept_compression=[compression], ) as client: