AberSheeran
Aber Sheeran

让一份代码同时支持同步与异步

起笔自
所属文集: 程序杂记
共计 5228 个字符
落笔于

在一般代码里,无论是同步还是异步,对于业务参数与返回结果的处理逻辑应当是一致的。有的时候我们为了同时支持同步和异步调用而不得不写两份代码,这会带来许多处理不一致的问题。单纯的把处理逻辑抽离出来作为独立函数,也不是特别能解决,在一个超大代码量的项目里其实很容易出现复制粘贴漏调用处理函数的情况。本文旨在提出一种思路,用于解决实际项目中不得不写同时维护 asyncsync 代码的问题。

众所周知,生成器(Generator)可以将一个完整函数,拆成多次调用执行。如果我们将 IO 部分剥离出来,把 IO 的参数和结果通过 yield 传递,那么同步和异步调用就可以共用同一个逻辑代码。

样例

import dataclasses
import os
import time
import typing
from typing import (
    Any,
    Awaitable,
    Callable,
    ClassVar,
    TypeVar,
    Generator,
    ParamSpec,
    Generic,
    Concatenate,
)

import httpx
import httpx._client
import httpx._types


class RemoteCall:
    _base_url: str
    _async_client: httpx.AsyncClient
    _sync_client: httpx.Client

    def __init__(self, *, base_url: str):
        self._base_url = base_url
        self.init_async_client()
        self.init_sync_client()

    def init_async_client(self):
        self._async_client = httpx.AsyncClient(base_url=self._base_url)

    def init_sync_client(self):
        self._sync_client = httpx.Client(base_url=self._base_url)

    async def __aenter__(self):
        if self._async_client.is_closed:
            self.init_async_client()
        return self

    async def __aexit__(self, exc_type, exc_value, traceback):
        await self._async_client.aclose()

    def __enter__(self):
        if self._sync_client.is_closed:
            self.init_sync_client()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self._sync_client.close()

    @staticmethod
    def _try_raise_http_exception(resp: httpx.Response) -> None:
        resp.raise_for_status()


P = ParamSpec("P")
R = TypeVar("R")


@dataclasses.dataclass
class IOCall(Generic[P, R]):
    _awaitable: Callable[Concatenate[RemoteCall, P], Awaitable[R]]
    _syncable: Callable[Concatenate[RemoteCall, P], R]
    this: RemoteCall

    def __call__(self, *args: P.args, **kwargs: P.kwargs) -> R:
        return self._syncable(self.this, *args, **kwargs)

    def awaitable(self, *args: P.args, **kwargs: P.kwargs) -> Awaitable[R]:
        return self._awaitable(self.this, *args, **kwargs)


class IOCallDescriptor(Generic[P, R]):
    def __init__(
        self,
        awaitable: Callable[Concatenate[RemoteCall, P], Awaitable[R]],
        syncable: Callable[Concatenate[RemoteCall, P], R],
    ):
        self.awaitable = awaitable
        self.syncable = syncable

    def __get__(self, instance: RemoteCall, owner: type[RemoteCall]) -> IOCall[P, R]:
        return IOCall(self.awaitable, self.syncable, instance)


@dataclasses.dataclass
class Request:
    method: str
    url: str

    content: httpx._types.RequestContent | None = None
    data: httpx._types.RequestData | None = None
    files: httpx._types.RequestFiles | None = None
    json: Any | None = None
    params: httpx._types.QueryParamTypes | None = None
    headers: httpx._types.HeaderTypes | None = None
    cookies: httpx._types.CookieTypes | None = None
    timeout: httpx._types.TimeoutTypes = None
    extensions: httpx._types.RequestExtensions | None = None


Response = httpx.Response


G = Generator[Request, Response, R]


def convert(
    func: Callable[Concatenate[typing.Any, P], Generator[Request, Response, R]],
) -> IOCallDescriptor[P, R]:
    async def async_wrapper(self: RemoteCall, *args: P.args, **kwargs: P.kwargs) -> R:
        g = func(self, *args, **kwargs)
        request = next(g)

        request = self._async_client.build_request(**dataclasses.asdict(request))
        resp = await self._async_client.send(request)
        self._try_raise_http_exception(resp)
        try:
            g.send(resp)
        except StopIteration as exc:
            return exc.value
        raise RuntimeError("Generator did not stop")

    def sync_wrapper(self: RemoteCall, *args: P.args, **kwargs: P.kwargs) -> R:
        g = func(self, *args, **kwargs)
        request = next(g)

        request = self._sync_client.build_request(**dataclasses.asdict(request))
        resp = self._sync_client.send(request)
        self._try_raise_http_exception(resp)
        try:
            g.send(resp)
        except StopIteration as exc:
            return exc.value
        raise RuntimeError("Generator did not stop")

    call = IOCallDescriptor(async_wrapper, sync_wrapper)
    return call

接下来看看怎么使用这些东西编写业务逻辑。同步调用就像调用一个普通的 method,而异步调用只需要用 .awaitable 追加在尾部。最妙的是,类型注释是可以被自动推导的,这里不会有任何参数或返回值的类型注释丢失。

from . import RemoteCall, convert, G, Request


class Session(RemoteCall):
    @convert
    def fetch_data(self, level: str) -> G[dict]:
        """
        fetch data
        """
        response = yield Request("POST", "/data", json={"level": level})
        return response.json()


session = Session(base_url="http://example.com")

r = session.fetch_data("xxxx")
# OR await
r = await session.fetch_data.awaitable("xxxx")

解析

convert 函数通过 send 把一个生成器 method 转为一个 IOCallDescriptor。而在 IOCallDescriptor__get__ 中把 instance 传递给 IOCall,又让我们自己创建的函数像一个普通 method 一样执行,并且能让 IOCall 读取到 RemoteCall 实例中的属性或方法加以使用。

Generic[P, R] 保证了所有的函数参数以及返回值被妥善的继承给 async_wrappersync_wrapper,让调用方无论通过同步还是异步调用都可以得到完整的类型推导帮助排除 BUG。

如果你觉得本文值得,不妨赏杯茶
在 1C 1G 的服务器上跑 mastodon
使用 zod 验证 Input file