在一般代码里,无论是同步还是异步,对于业务参数与返回结果的处理逻辑应当是一致的。有的时候我们为了同时支持同步和异步调用而不得不写两份代码,这会带来许多处理不一致的问题。单纯的把处理逻辑抽离出来作为独立函数,也不是特别能解决,在一个超大代码量的项目里其实很容易出现复制粘贴漏调用处理函数的情况。本文旨在提出一种思路,用于解决实际项目中不得不写同时维护 async
和 sync
代码的问题。
众所周知,生成器(Generator
)可以将一个完整函数,拆成多次调用执行。如果我们将 IO 部分剥离出来,把 IO 的参数和结果通过 yield
传递,那么同步和异步调用就可以共用同一个逻辑代码。
样例
import dataclasses
import os
import time
import typing
from typing import (
Any,
Awaitable,
Callable,
ClassVar,
TypeVar,
Generator,
ParamSpec,
Generic,
Concatenate,
)
import httpx
import httpx._client
import httpx._types
class RemoteCall:
_base_url: str
_async_client: httpx.AsyncClient
_sync_client: httpx.Client
def __init__(self, *, base_url: str):
self._base_url = base_url
self.init_async_client()
self.init_sync_client()
def init_async_client(self):
self._async_client = httpx.AsyncClient(base_url=self._base_url)
def init_sync_client(self):
self._sync_client = httpx.Client(base_url=self._base_url)
async def __aenter__(self):
if self._async_client.is_closed:
self.init_async_client()
return self
async def __aexit__(self, exc_type, exc_value, traceback):
await self._async_client.aclose()
def __enter__(self):
if self._sync_client.is_closed:
self.init_sync_client()
return self
def __exit__(self, exc_type, exc_value, traceback):
self._sync_client.close()
@staticmethod
def _try_raise_http_exception(resp: httpx.Response) -> None:
resp.raise_for_status()
P = ParamSpec("P")
R = TypeVar("R")
@dataclasses.dataclass
class IOCall(Generic[P, R]):
_awaitable: Callable[Concatenate[RemoteCall, P], Awaitable[R]]
_syncable: Callable[Concatenate[RemoteCall, P], R]
this: RemoteCall
def __call__(self, *args: P.args, **kwargs: P.kwargs) -> R:
return self._syncable(self.this, *args, **kwargs)
def awaitable(self, *args: P.args, **kwargs: P.kwargs) -> Awaitable[R]:
return self._awaitable(self.this, *args, **kwargs)
class IOCallDescriptor(Generic[P, R]):
def __init__(
self,
awaitable: Callable[Concatenate[RemoteCall, P], Awaitable[R]],
syncable: Callable[Concatenate[RemoteCall, P], R],
):
self.awaitable = awaitable
self.syncable = syncable
def __get__(self, instance: RemoteCall, owner: type[RemoteCall]) -> IOCall[P, R]:
return IOCall(self.awaitable, self.syncable, instance)
@dataclasses.dataclass
class Request:
method: str
url: str
content: httpx._types.RequestContent | None = None
data: httpx._types.RequestData | None = None
files: httpx._types.RequestFiles | None = None
json: Any | None = None
params: httpx._types.QueryParamTypes | None = None
headers: httpx._types.HeaderTypes | None = None
cookies: httpx._types.CookieTypes | None = None
timeout: httpx._types.TimeoutTypes = None
extensions: httpx._types.RequestExtensions | None = None
Response = httpx.Response
G = Generator[Request, Response, R]
def convert(
func: Callable[Concatenate[typing.Any, P], Generator[Request, Response, R]],
) -> IOCallDescriptor[P, R]:
async def async_wrapper(self: RemoteCall, *args: P.args, **kwargs: P.kwargs) -> R:
g = func(self, *args, **kwargs)
request = next(g)
request = self._async_client.build_request(**dataclasses.asdict(request))
resp = await self._async_client.send(request)
self._try_raise_http_exception(resp)
try:
g.send(resp)
except StopIteration as exc:
return exc.value
raise RuntimeError("Generator did not stop")
def sync_wrapper(self: RemoteCall, *args: P.args, **kwargs: P.kwargs) -> R:
g = func(self, *args, **kwargs)
request = next(g)
request = self._sync_client.build_request(**dataclasses.asdict(request))
resp = self._sync_client.send(request)
self._try_raise_http_exception(resp)
try:
g.send(resp)
except StopIteration as exc:
return exc.value
raise RuntimeError("Generator did not stop")
call = IOCallDescriptor(async_wrapper, sync_wrapper)
return call
接下来看看怎么使用这些东西编写业务逻辑。同步调用就像调用一个普通的 method
,而异步调用只需要用 .awaitable
追加在尾部。最妙的是,类型注释是可以被自动推导的,这里不会有任何参数或返回值的类型注释丢失。
from . import RemoteCall, convert, G, Request
class Session(RemoteCall):
@convert
def fetch_data(self, level: str) -> G[dict]:
"""
fetch data
"""
response = yield Request("POST", "/data", json={"level": level})
return response.json()
session = Session(base_url="http://example.com")
r = session.fetch_data("xxxx")
# OR await
r = await session.fetch_data.awaitable("xxxx")
解析
convert
函数通过 send
把一个生成器 method
转为一个 IOCallDescriptor
。而在 IOCallDescriptor
的 __get__
中把 instance
传递给 IOCall
,又让我们自己创建的函数像一个普通 method
一样执行,并且能让 IOCall
读取到 RemoteCall
实例中的属性或方法加以使用。
Generic[P, R]
保证了所有的函数参数以及返回值被妥善的继承给 async_wrapper
和 sync_wrapper
,让调用方无论通过同步还是异步调用都可以得到完整的类型推导帮助排除 BUG。