Skip to content

Commit 9c14bba

Browse files
committed
Allow customization of server and client codecs
Signed-off-by: Anuraag Agrawal <anuraaga@gmail.com>
1 parent a1d1f63 commit 9c14bba

File tree

14 files changed

+230
-39
lines changed

14 files changed

+230
-39
lines changed

conformance/test/client.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@
4545

4646
from connectrpc.client import ResponseMetadata
4747
from connectrpc.code import Code
48+
from connectrpc.codec import proto_json_codec
4849
from connectrpc.compression.brotli import BrotliCompression
4950
from connectrpc.compression.gzip import GzipCompression
5051
from connectrpc.compression.zstd import ZstdCompression
@@ -173,7 +174,9 @@ async def client_sync(
173174
ZstdCompression(),
174175
],
175176
send_compression=_convert_compression(test_request.compression),
176-
proto_json=test_request.codec == Codec.CODEC_JSON,
177+
codec=proto_json_codec()
178+
if test_request.codec == Codec.CODEC_JSON
179+
else None,
177180
protocol=protocol,
178181
read_max_bytes=read_max_bytes,
179182
) as client,
@@ -220,7 +223,9 @@ async def client_async(
220223
ZstdCompression(),
221224
],
222225
send_compression=_convert_compression(test_request.compression),
223-
proto_json=test_request.codec == Codec.CODEC_JSON,
226+
codec=proto_json_codec()
227+
if test_request.codec == Codec.CODEC_JSON
228+
else None,
224229
protocol=protocol,
225230
read_max_bytes=read_max_bytes,
226231
) as client,

conformance/test/gen/connectrpc/conformance/v1/service_connect.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
Mapping,
2727
)
2828

29+
from connectrpc.codec import Codec
2930
from connectrpc.compression import Compression
3031
from connectrpc.interceptor import Interceptor, InterceptorSync
3132
from connectrpc.request import Headers, RequestContext
@@ -91,6 +92,7 @@ def __init__(
9192
interceptors: Iterable[Interceptor] = (),
9293
read_max_bytes: int | None = None,
9394
compressions: Iterable[Compression] | None = None,
95+
codecs: Iterable[Codec] | None = None,
9496
) -> None:
9597
super().__init__(
9698
service=service,
@@ -159,6 +161,7 @@ def __init__(
159161
interceptors=interceptors,
160162
read_max_bytes=read_max_bytes,
161163
compressions=compressions,
164+
codecs=codecs,
162165
)
163166

164167
@property
@@ -358,6 +361,7 @@ def __init__(
358361
interceptors: Iterable[InterceptorSync] = (),
359362
read_max_bytes: int | None = None,
360363
compressions: Iterable[Compression] | None = None,
364+
codecs: Iterable[Codec] | None = None,
361365
) -> None:
362366
super().__init__(
363367
endpoints={
@@ -425,6 +429,7 @@ def __init__(
425429
interceptors=interceptors,
426430
read_max_bytes=read_max_bytes,
427431
compressions=compressions,
432+
codecs=codecs,
428433
)
429434

430435
@property

docs/api.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111

1212
::: connectrpc.interceptor
1313

14+
::: connectrpc.codec
15+
1416
::: connectrpc.compression
1517
::: connectrpc.compression.brotli
1618
::: connectrpc.compression.gzip

example/example/eliza_connect.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
Mapping,
2727
)
2828

29+
from connectrpc.codec import Codec
2930
from connectrpc.compression import Compression
3031
from connectrpc.interceptor import Interceptor, InterceptorSync
3132
from connectrpc.request import Headers, RequestContext
@@ -58,6 +59,7 @@ def __init__(
5859
interceptors: Iterable[Interceptor] = (),
5960
read_max_bytes: int | None = None,
6061
compressions: Iterable[Compression] | None = None,
62+
codecs: Iterable[Codec] | None = None,
6163
) -> None:
6264
super().__init__(
6365
service=service,
@@ -96,6 +98,7 @@ def __init__(
9698
interceptors=interceptors,
9799
read_max_bytes=read_max_bytes,
98100
compressions=compressions,
101+
codecs=codecs,
99102
)
100103

101104
@property
@@ -194,6 +197,7 @@ def __init__(
194197
interceptors: Iterable[InterceptorSync] = (),
195198
read_max_bytes: int | None = None,
196199
compressions: Iterable[Compression] | None = None,
200+
codecs: Iterable[Codec] | None = None,
197201
) -> None:
198202
super().__init__(
199203
endpoints={
@@ -231,6 +235,7 @@ def __init__(
231235
interceptors=interceptors,
232236
read_max_bytes=read_max_bytes,
233237
compressions=compressions,
238+
codecs=codecs,
234239
)
235240

236241
@property

protoc-gen-connect-python/generator/template.go

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ from typing import Protocol
4949
5050
from connectrpc.client import ConnectClient, ConnectClientSync
5151
from connectrpc.code import Code
52+
from connectrpc.codec import Codec
5253
from connectrpc.compression import Compression
5354
from connectrpc.errors import ConnectError
5455
from connectrpc.interceptor import Interceptor, InterceptorSync
@@ -69,7 +70,7 @@ class {{.Name}}(Protocol):{{- range .Methods }}
6970
{{ end }}
7071
7172
class {{.Name}}ASGIApplication(ConnectASGIApplication[{{.Name}}]):
72-
def __init__(self, service: {{.Name}} | AsyncGenerator[{{.Name}}], *, interceptors: Iterable[Interceptor]=(), read_max_bytes: int | None = None, compressions: Iterable[Compression] | None = None) -> None:
73+
def __init__(self, service: {{.Name}} | AsyncGenerator[{{.Name}}], *, interceptors: Iterable[Interceptor]=(), read_max_bytes: int | None = None, compressions: Iterable[Compression] | None = None, codecs: Iterable[Codec] | None = None) -> None:
7374
super().__init__(
7475
service=service,
7576
endpoints=lambda svc: { {{- range .Methods }}
@@ -87,6 +88,7 @@ class {{.Name}}ASGIApplication(ConnectASGIApplication[{{.Name}}]):
8788
interceptors=interceptors,
8889
read_max_bytes=read_max_bytes,
8990
compressions=compressions,
91+
codecs=codecs,
9092
)
9193
9294
@property
@@ -130,7 +132,7 @@ class {{.Name}}Sync(Protocol):{{- range .Methods }}
130132
131133
132134
class {{.Name}}WSGIApplication(ConnectWSGIApplication):
133-
def __init__(self, service: {{.Name}}Sync, interceptors: Iterable[InterceptorSync]=(), read_max_bytes: int | None = None, compressions: Iterable[Compression] | None = None) -> None:
135+
def __init__(self, service: {{.Name}}Sync, interceptors: Iterable[InterceptorSync]=(), read_max_bytes: int | None = None, compressions: Iterable[Compression] | None = None, codecs: Iterable[Codec] | None = None) -> None:
134136
super().__init__(
135137
endpoints={ {{- range .Methods }}
136138
"/{{.ServiceName}}/{{.Name}}": EndpointSync.{{.EndpointType}}(
@@ -147,6 +149,7 @@ class {{.Name}}WSGIApplication(ConnectWSGIApplication):
147149
interceptors=interceptors,
148150
read_max_bytes=read_max_bytes,
149151
compressions=compressions,
152+
codecs=codecs,
150153
)
151154
152155
@property

src/connectrpc/_client_async.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212

1313
from . import _client_shared
1414
from ._asyncio_timeout import timeout as asyncio_timeout
15-
from ._codec import Codec, get_proto_binary_codec, get_proto_json_codec
15+
from ._codec import proto_binary_codec
1616
from ._compression import IdentityCompression, _gzip, resolve_compressions
1717
from ._interceptor_async import (
1818
BidiStreamInterceptor,
@@ -43,6 +43,7 @@
4343
from types import TracebackType
4444

4545
from ._envelope import EnvelopeReader
46+
from .codec import Codec
4647
from .compression import Compression
4748
from .method import MethodInfo
4849
from .request import Headers, RequestContext
@@ -92,7 +93,7 @@ def __init__(
9293
self,
9394
address: str,
9495
*,
95-
proto_json: bool = False,
96+
codec: Codec | None = None,
9697
protocol: ProtocolType = ProtocolType.CONNECT,
9798
accept_compression: Iterable[Compression] | None = None,
9899
send_compression: Compression | None = _gzip,
@@ -105,7 +106,8 @@ def __init__(
105106
106107
Args:
107108
address: The address of the server to connect to, including scheme.
108-
proto_json: Whether to use JSON for the protocol.
109+
codec: The [Codec][] to use for requests. If unset, defaults to binary protobuf.
110+
For JSON encoding, use [proto_json_codec][connectrpc.codec.proto_json_codec].
109111
protocol: The [ProtocolType][] to use for requests.
110112
accept_compression: Compression algorithms to accept from the server. If unset,
111113
defaults to gzip. If set to empty, disables response compression.
@@ -117,7 +119,7 @@ def __init__(
117119
http_client: A pyqwest Client to use for requests.
118120
"""
119121
self._address = address
120-
self._codec = get_proto_json_codec() if proto_json else get_proto_binary_codec()
122+
self._codec = codec or proto_binary_codec()
121123
self._response_compressions = resolve_compressions(accept_compression)
122124
self._accept_compression_header = ",".join(self._response_compressions.keys())
123125
self._send_compression = send_compression or IdentityCompression()

src/connectrpc/_client_sync.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from connectrpc._protocol_grpc import GRPCClientProtocol, GRPCWebClientProtocol
1111

1212
from . import _client_shared
13-
from ._codec import Codec, get_proto_binary_codec, get_proto_json_codec
13+
from ._codec import proto_binary_codec
1414
from ._compression import IdentityCompression, _gzip, resolve_compressions
1515
from ._interceptor_sync import (
1616
BidiStreamInterceptorSync,
@@ -33,6 +33,7 @@
3333
from types import TracebackType
3434

3535
from ._envelope import EnvelopeReader
36+
from .codec import Codec
3637
from .compression import Compression
3738
from .method import MethodInfo
3839
from .request import Headers, RequestContext
@@ -82,7 +83,7 @@ def __init__(
8283
self,
8384
address: str,
8485
*,
85-
proto_json: bool = False,
86+
codec: Codec | None = None,
8687
protocol: ProtocolType = ProtocolType.CONNECT,
8788
accept_compression: Iterable[Compression] | None = None,
8889
send_compression: Compression | None = _gzip,
@@ -95,7 +96,8 @@ def __init__(
9596
9697
Args:
9798
address: The address of the server to connect to, including scheme.
98-
proto_json: Whether to use JSON for the protocol.
99+
codec: The [Codec][] to use for requests. If unset, defaults to binary protobuf.
100+
For JSON encoding, use [proto_json_codec][connectrpc.codec.proto_json_codec].
99101
protocol: The [ProtocolType][] to use for requests.
100102
accept_compression: Compression algorithms to accept from the server. If unset,
101103
defaults to gzip. If set to empty, disables response compression.
@@ -107,7 +109,7 @@ def __init__(
107109
http_client: A pyqwest SyncClient to use for requests.
108110
"""
109111
self._address = address
110-
self._codec = get_proto_json_codec() if proto_json else get_proto_binary_codec()
112+
self._codec = codec or proto_binary_codec()
111113
self._timeout_ms = timeout_ms
112114
self._read_max_bytes = read_max_bytes
113115
self._response_compressions = resolve_compressions(accept_compression)

src/connectrpc/_codec.py

Lines changed: 21 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,10 @@
2020

2121
class Codec(Protocol[T_contra, U]):
2222
def name(self) -> str:
23-
"""Returns the name of the codec."""
23+
"""Returns the name of the codec.
24+
25+
This corresponds to the content-type used in requests.
26+
"""
2427
...
2528

2629
def encode(self, message: T_contra) -> bytes:
@@ -49,8 +52,11 @@ def decode(self, data: bytes | bytearray, message: V) -> V:
4952
class ProtoJSONCodec(Codec[Message, V]):
5053
"""Codec for Protocol bytes | bytearrays JSON format."""
5154

55+
def __init__(self, name: str = "json") -> None:
56+
self._name = name
57+
5258
def name(self) -> str:
53-
return "json"
59+
return self._name
5460

5561
def encode(self, message: Message) -> bytes:
5662
return MessageToJson(message).encode()
@@ -60,27 +66,24 @@ def decode(self, data: bytes | bytearray, message: V) -> V:
6066
return message
6167

6268

63-
# TODO: Codecs can generally be customized per handler instead of as a global
64-
# registry, though the usage isn't common.
6569
_proto_binary_codec = ProtoBinaryCodec()
6670
_proto_json_codec = ProtoJSONCodec()
67-
_codecs = {
68-
CODEC_NAME_PROTO: _proto_binary_codec,
69-
CODEC_NAME_JSON: _proto_json_codec,
70-
CODEC_NAME_JSON_CHARSET_UTF8: _proto_json_codec,
71-
}
71+
_default_codecs = [
72+
_proto_binary_codec,
73+
_proto_json_codec,
74+
ProtoJSONCodec(name=CODEC_NAME_JSON_CHARSET_UTF8),
75+
]
7276

7377

74-
def get_proto_binary_codec() -> Codec:
75-
"""Returns the Protocol bytes | bytearrays binary codec."""
76-
return _proto_binary_codec
78+
def get_default_codecs() -> list[Codec]:
79+
return _default_codecs
7780

7881

79-
def get_proto_json_codec() -> Codec:
80-
"""Returns the Protocol bytes | bytearrays JSON codec."""
81-
return _proto_json_codec
82+
def proto_binary_codec() -> Codec:
83+
"""Returns the Protocol Buffers binary codec."""
84+
return _proto_binary_codec
8285

8386

84-
def get_codec(name: str) -> Codec | None:
85-
"""Returns the codec with the given name."""
86-
return _codecs.get(name)
87+
def proto_json_codec() -> Codec:
88+
"""Returns the Protocol Buffers JSON codec."""
89+
return _proto_json_codec

src/connectrpc/_server_async.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from typing import TYPE_CHECKING, Generic, TypeVar, cast
1212
from urllib.parse import parse_qs
1313

14-
from ._codec import Codec, get_codec
14+
from ._codec import Codec, get_default_codecs
1515
from ._compression import negotiate_compression, resolve_compressions
1616
from ._envelope import EnvelopeReader
1717
from ._interceptor_async import (
@@ -91,6 +91,7 @@ def __init__(
9191
interceptors: Iterable[Interceptor] = (),
9292
read_max_bytes: int | None = None,
9393
compressions: Iterable[Compression] | None = None,
94+
codecs: Iterable[Codec] | None = None,
9495
) -> None:
9596
"""Initialize the ASGI application.
9697
@@ -103,6 +104,8 @@ def __init__(
103104
read_max_bytes: Maximum size of request messages.
104105
compressions: Supported compression algorithms. If unset, defaults to gzip.
105106
If set to empty, disables compression.
107+
codecs: The codecs supported by the server. If unset, defaults to Protocol Buffers
108+
binary and JSON codecs.
106109
"""
107110
super().__init__()
108111
self._service = service
@@ -111,6 +114,8 @@ def __init__(
111114
self._resolved_endpoints = None
112115
self._read_max_bytes = read_max_bytes
113116
self._compressions = resolve_compressions(compressions)
117+
codecs = codecs if codecs is not None else get_default_codecs()
118+
self._codecs = {codec.name(): codec for codec in codecs}
114119

115120
async def __call__(
116121
self, scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable
@@ -208,7 +213,7 @@ async def __call__(
208213
codec_name = protocol.codec_name_from_content_type(
209214
headers.get("content-type", ""), stream=not is_unary
210215
)
211-
codec = get_codec(codec_name.lower())
216+
codec = self._codecs.get(codec_name.lower())
212217
if not codec:
213218
raise HTTPException(
214219
HTTPStatus.UNSUPPORTED_MEDIA_TYPE,

0 commit comments

Comments
 (0)