1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83
|
from __future__ import annotations
import sys
import inspect
import typing_extensions
from typing import get_args
import pytest
from openai import OpenAI, AsyncOpenAI
from tests.utils import evaluate_forwardref
from openai._utils import assert_signatures_in_sync
from openai._compat import is_literal_type
from openai._utils._typing import is_union_type
from openai.types.audio_response_format import AudioResponseFormat
@pytest.mark.parametrize("sync", [True, False], ids=["sync", "async"])
def test_translation_create_overloads_in_sync(sync: bool, client: OpenAI, async_client: AsyncOpenAI) -> None:
    """Check the overloads of ``audio.translations.create`` against its implementation.

    For every overload registered for ``create`` this test:

    * passes the implementation and the overload to ``assert_signatures_in_sync``,
      excluding the ``response_format`` and ``stream`` parameters from that check;
    * collects the literal values found in the overload's ``response_format``
      annotation.

    It then asserts that every value of ``AudioResponseFormat`` was collected
    from at least one overload.

    Args:
        sync: Selects the synchronous client when true, the async client otherwise.
        client: Synchronous client fixture.
        async_client: Asynchronous client fixture.
    """
    checking_client: OpenAI | AsyncOpenAI = client if sync else async_client

    fn = checking_client.audio.translations.create
    overload_response_formats: set[str] = set()

    for i, overload in enumerate(typing_extensions.get_overloads(fn)):
        assert_signatures_in_sync(
            fn,
            overload,
            exclude_params={"response_format", "stream"},
            description=f" for overload {i}",
        )

        sig = inspect.signature(overload)
        # Annotations are lazy strings under `from __future__ import annotations`,
        # so resolve them against the namespace of the module defining `fn`.
        typ = evaluate_forwardref(
            sig.parameters["response_format"].annotation,
            globalns=sys.modules[fn.__module__].__dict__,
        )

        # A union contributes each of its members; any other annotation is
        # inspected on its own. Only literal types carry format values.
        candidates = get_args(typ) if is_union_type(typ) else (typ,)
        for candidate in candidates:
            if is_literal_type(candidate):
                overload_response_formats.update(get_args(candidate))

    src_response_formats: set[str] = set(get_args(AudioResponseFormat))
    diff = src_response_formats.difference(overload_response_formats)
    # Name the missing formats so a failure is actionable without a debugger.
    assert not diff, f"some response format options don't have overloads: {sorted(diff)}"
@pytest.mark.parametrize("sync", [True, False], ids=["sync", "async"])
def test_transcription_create_overloads_in_sync(sync: bool, client: OpenAI, async_client: AsyncOpenAI) -> None:
    """Check the overloads of ``audio.transcriptions.create`` against its implementation.

    For every overload registered for ``create`` this test:

    * passes the implementation and the overload to ``assert_signatures_in_sync``,
      excluding the ``response_format`` and ``stream`` parameters from that check;
    * collects the literal values found in the overload's ``response_format``
      annotation.

    It then asserts that every value of ``AudioResponseFormat`` was collected
    from at least one overload.

    Args:
        sync: Selects the synchronous client when true, the async client otherwise.
        client: Synchronous client fixture.
        async_client: Asynchronous client fixture.
    """
    checking_client: OpenAI | AsyncOpenAI = client if sync else async_client

    fn = checking_client.audio.transcriptions.create
    overload_response_formats: set[str] = set()

    for i, overload in enumerate(typing_extensions.get_overloads(fn)):
        assert_signatures_in_sync(
            fn,
            overload,
            exclude_params={"response_format", "stream"},
            description=f" for overload {i}",
        )

        sig = inspect.signature(overload)
        # Annotations are lazy strings under `from __future__ import annotations`,
        # so resolve them against the namespace of the module defining `fn`.
        typ = evaluate_forwardref(
            sig.parameters["response_format"].annotation,
            globalns=sys.modules[fn.__module__].__dict__,
        )

        # A union contributes each of its members; any other annotation is
        # inspected on its own. Only literal types carry format values.
        candidates = get_args(typ) if is_union_type(typ) else (typ,)
        for candidate in candidates:
            if is_literal_type(candidate):
                overload_response_formats.update(get_args(candidate))

    src_response_formats: set[str] = set(get_args(AudioResponseFormat))
    diff = src_response_formats.difference(overload_response_formats)
    # Name the missing formats so a failure is actionable without a debugger.
    assert not diff, f"some response format options don't have overloads: {sorted(diff)}"
|