AberSheeran
Aber Sheeran

让一份代码同时支持同步与异步

起笔自
所属文集: 程序杂记
共计 4980 个字符
落笔于

在一般代码里,无论是同步还是异步,对于业务参数与返回结果的处理逻辑应当是一致的。有的时候我们为了同时支持同步和异步调用而不得不写两份代码,这会带来许多处理不一致的问题。单纯的把处理逻辑抽离出来作为独立函数,也不是特别能解决,在一个超大代码量的项目里其实很容易出现复制粘贴漏调用处理函数的情况。本文旨在提出一种思路,用于解决实际项目中不得不写同时维护 asyncsync 代码的问题。

众所周知,生成器(Generator)可以将一个完整函数,拆成多次调用执行。如果我们将 IO 部分剥离出来,把 IO 的参数和结果通过 yield 传递,那么同步和异步调用就可以共用同一个逻辑代码。以下是一个完整可用的样例。

import dataclasses
import os
import time
import typing
from typing import (
    Any,
    Awaitable,
    Callable,
    ClassVar,
    TypeVar,
    Generator,
    ParamSpec,
    Generic,
    Concatenate,
)

import httpx
import httpx._client
import httpx._types
from loguru import logger
from baize.exceptions import HTTPException


class RemoteCall:
    _env: str
    _async_client: httpx.AsyncClient
    _sync_client: httpx.Client

    @staticmethod
    def get_env(env: str) -> str:
        base_url = os.environ.get(env)
        if not base_url:
            raise RuntimeError(f"env {env} is not set")
        return base_url.rstrip("/")

    @classmethod
    def init_async_client(cls):
        cls._async_client = httpx.AsyncClient(base_url=cls.get_env(cls._env))

    @classmethod
    def init_sync_client(cls):
        cls._sync_client = httpx.Client(base_url=cls.get_env(cls._env))

    def __init_subclass__(cls, env: str):
        cls._env = env
        cls.init_async_client()
        cls.init_sync_client()

    async def __aenter__(self):
        if self._async_client.is_closed:
            self.init_async_client()
        return self

    async def __aexit__(self, exc_type, exc_value, traceback):
        await self._async_client.aclose()

    def __enter__(self):
        if self._sync_client.is_closed:
            self.init_sync_client()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self._sync_client.close()

    @staticmethod
    def _try_raise_http_exception(resp: httpx.Response) -> None:
        try:
            resp.raise_for_status()
        except httpx.HTTPStatusError:
            raise HTTPException(
                500,
                content=f"Failed while requesting {resp.url}: {resp.status_code}\n\n {resp.text}",
            )
        except httpx.TransportError as e:
            raise HTTPException(
                500,
                content=f"Failed while requesting {resp.url}: {e}",
            )


P = ParamSpec("P")
R = TypeVar("R")


@dataclasses.dataclass
class IOCall(Generic[P, R]):
    Async: Callable[P, Awaitable[R]]
    Sync: Callable[P, R]

    cls: ClassVar[type[RemoteCall]]

    def __set_name__(self, owner, name):
        setattr(self, "cls", owner)


@dataclasses.dataclass
class Request:
    method: str
    url: str

    content: httpx._types.RequestContent | None = None
    data: httpx._types.RequestData | None = None
    files: httpx._types.RequestFiles | None = None
    json: Any | None = None
    params: httpx._types.QueryParamTypes | None = None
    headers: httpx._types.HeaderTypes | None = None
    cookies: httpx._types.CookieTypes | None = None
    timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = (
        httpx.USE_CLIENT_DEFAULT
    )
    extensions: httpx._types.RequestExtensions | None = None


Response = httpx.Response


G = Generator[Request, Response, R]


def convert(
    func: Callable[Concatenate[typing.Any, P], Generator[Request, Response, R]]
) -> IOCall[P, R]:
    async def async_wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
        cls = call.cls
        g = func(cls, *args, **kwargs)
        request = next(g)
        request = cls._async_client.build_request(**dataclasses.asdict(request))
        resp = await cls._async_client.send(request)
        cls._try_raise_http_exception(resp)
        try:
            g.send(resp)
        except StopIteration as exc:
            return exc.value
        raise RuntimeError("Generator did not stop")

    def sync_wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
        cls = call.cls
        g = func(cls, *args, **kwargs)
        request = next(g)
        request = cls._sync_client.build_request(**dataclasses.asdict(request))
        resp = cls._sync_client.send(request)
        cls._try_raise_http_exception(resp)
        try:
            g.send(resp)
        except StopIteration as exc:
            return exc.value
        raise RuntimeError("Generator did not stop")

    call = IOCall(async_wrapper, sync_wrapper)
    return call

接下来看看怎么使用这些东西编写业务逻辑。比较值得注意的是,第一个参数是 cls 而不是 self,调用时可以直接使用 Call.fetch_data.Sync(level) 拉取结果,或者使用异步代码 await Call.fetch_data.Async(level)

from . import RemoteCall, convert, G, Request


class Call(RemoteCall, env="LOG"):
    @convert
    def fetch_data(cls, level: str) -> G[str]:
        """
        fetch data
        """
        response = yield Request("POST", "/data", json={"level": level})
        ...
        return "xxx"
如果你觉得本文值得,不妨赏杯茶
在 1C 1G 的服务器上跑 mastodon
使用 zod 验证 Input file