File: test_assistants_async.py

package info (click to toggle)
python-azure 20250603%2Bgit-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 851,724 kB
  • sloc: python: 7,362,925; ansic: 804; javascript: 287; makefile: 195; sh: 145; xml: 109
file content (639 lines) | stat: -rw-r--r-- 27,353 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
# ------------------------------------
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# ------------------------------------

import os
import pytest
import pathlib
import uuid
import openai
from devtools_testutils import AzureRecordedTestCase
from conftest import ASST_AZURE, PREVIEW, GPT_4_OPENAI, configure_async
from openai import AsyncAssistantEventHandler
from openai.types.beta.threads import (
    Text,
    Message,
    ImageFile,
    TextDelta,
    MessageDelta,
)
from openai.types.beta.threads import Run
from openai.types.beta.threads.runs import RunStep, ToolCall, RunStepDelta, ToolCallDelta


class AsyncEventHandler(AsyncAssistantEventHandler):
    """Assistant stream event handler whose callbacks assert on streamed objects.

    Each ``on_*`` override checks that the identifiers, indices and text the
    tests rely on are populated on the object it receives. An instance is
    passed as ``event_handler`` to ``client_async.beta.threads.runs.stream``
    by ``TestAssistantsAsync.test_assistants_stream_event_handler``.
    """

    async def on_text_delta(self, delta: TextDelta, snapshot: Text) -> None:
        """Check an incremental text update and any annotations it carries."""
        if delta.value:
            assert delta.value is not None
        if delta.annotations:
            for annotation in delta.annotations:
                if annotation.type == "file_citation":
                    assert annotation.index is not None
                    assert annotation.file_citation.file_id
                    assert annotation.file_citation.quote
                elif annotation.type == "file_path":
                    assert annotation.index is not None
                    assert annotation.file_path.file_id

    async def on_run_step_done(self, run_step: RunStep) -> None:
        """Check each tool call recorded on a finished ``tool_calls`` run step."""
        details = run_step.step_details
        if details.type == "tool_calls":
            for tool in details.tool_calls:
                if tool.type == "code_interpreter":
                    assert tool.id
                    assert tool.code_interpreter.input is not None
                elif tool.type == "function":
                    assert tool.id
                    assert tool.function.arguments is not None
                    assert tool.function.name is not None

    async def on_run_step_delta(self, delta: RunStepDelta, snapshot: RunStep) -> None:
        """Check an incremental run-step update (tool calls or message creation)."""
        details = delta.step_details
        if details is not None:
            if details.type == "tool_calls":
                # tool_calls may be None on a delta, hence the `or []`.
                for tool in details.tool_calls or []:
                    if tool.type == "code_interpreter" and tool.code_interpreter and tool.code_interpreter.input:
                        assert tool.index is not None
                        assert tool.code_interpreter.input is not None
            elif details.type == "message_creation":
                assert details.message_creation.message_id

    async def on_run_step_created(self, run_step: RunStep) -> None:
        """Check that a newly created run step has all its identifying fields set."""
        assert run_step.object == "thread.run.step"
        assert run_step.id
        assert run_step.type
        assert run_step.created_at
        assert run_step.assistant_id
        assert run_step.thread_id
        assert run_step.run_id
        assert run_step.status
        assert run_step.step_details

    async def on_message_created(self, message: Message) -> None:
        """Check that a newly created thread message has its core fields set."""
        assert message.object == "thread.message"
        assert message.id
        assert message.created_at
        assert message.attachments is not None
        assert message.status
        assert message.thread_id

    async def on_message_delta(self, delta: MessageDelta, snapshot: Message) -> None:
        """Check incremental message content: text (with annotations) and image files."""
        if delta.content:
            for content in delta.content:
                if content.type == "text":
                    assert content.index is not None
                    if content.text:
                        if content.text.value:
                            assert content.text.value is not None
                        if content.text.annotations:
                            for annotation in content.text.annotations:
                                if annotation.type == "file_citation":
                                    assert annotation.end_index is not None
                                    assert annotation.file_citation.file_id
                                    assert annotation.file_citation.quote
                                    assert annotation.start_index is not None
                                elif annotation.type == "file_path":
                                    assert annotation.end_index is not None
                                    assert annotation.file_path.file_id
                                    assert annotation.start_index is not None
                elif content.type == "image_file":
                    assert content.index is not None
                    assert content.image_file.file_id

    async def on_message_done(self, message: Message) -> None:
        """Check every content part of a completed message."""
        for msg in message.content:
            if msg.type == "image_file":
                assert msg.image_file.file_id
            if msg.type == "text":
                assert msg.text.value
                if msg.text.annotations:
                    for annotation in msg.text.annotations:
                        if annotation.type == "file_citation":
                            assert annotation.end_index is not None
                            assert annotation.file_citation.file_id
                            assert annotation.file_citation.quote
                            assert annotation.start_index is not None
                            assert annotation.text is not None
                        elif annotation.type == "file_path":
                            assert annotation.end_index is not None
                            assert annotation.file_path.file_id
                            assert annotation.start_index is not None
                            assert annotation.text is not None

    async def on_text_created(self, text: Text) -> None:
        """Check that a newly created text part carries a value."""
        assert text.value is not None

    async def on_text_done(self, text: Text) -> None:
        """Check a completed text part and all of its annotations."""
        assert text.value is not None
        for annotation in text.annotations:
            if annotation.type == "file_citation":
                assert annotation.end_index is not None
                assert annotation.file_citation.file_id
                assert annotation.file_citation.quote
                assert annotation.start_index is not None
                assert annotation.text is not None
            elif annotation.type == "file_path":
                assert annotation.end_index is not None
                assert annotation.file_path.file_id
                assert annotation.start_index is not None
                assert annotation.text is not None

    async def on_image_file_done(self, image_file: ImageFile) -> None:
        """Check that a completed image file references a file id."""
        assert image_file.file_id

    async def on_tool_call_created(self, tool_call: ToolCall) -> None:
        """Check that a newly created tool call has an id."""
        assert tool_call.id

    async def on_tool_call_delta(self, delta: ToolCallDelta, snapshot: ToolCall) -> None:
        """Check an incremental tool-call update.

        ``delta.code_interpreter`` can be unset on a code-interpreter delta
        (the guard below tests for it). Both its ``input`` and its ``outputs``
        are therefore only read inside that guard, so that a missing payload
        cannot raise ``AttributeError``.
        """
        if delta.type == "code_interpreter":
            assert delta.index is not None
            if delta.code_interpreter:
                if delta.code_interpreter.input:
                    assert delta.code_interpreter.input is not None
                if delta.code_interpreter.outputs:
                    for output in delta.code_interpreter.outputs:
                        if output.type == "image":
                            assert output.image.file_id
                        elif output.type == "logs":
                            assert output.logs
        if delta.type == "function":
            assert delta.id
            if delta.function:
                assert delta.function.arguments is not None
                assert delta.function.name is not None

    async def on_tool_call_done(self, tool_call: ToolCall) -> None:
        """Check a completed code-interpreter or function tool call."""
        if tool_call.type == "code_interpreter":
            assert tool_call.id
            assert tool_call.code_interpreter.input is not None
            for output in tool_call.code_interpreter.outputs:
                if output.type == "image":
                    assert output.image.file_id
                elif output.type == "logs":
                    assert output.logs
        if tool_call.type == "function":
            assert tool_call.id
            assert tool_call.function.arguments is not None
            assert tool_call.function.name is not None


@pytest.mark.live_test_only
class TestAssistantsAsync(AzureRecordedTestCase):
    """Live async tests for the Assistants (beta) API of ``client_async``.

    Each test creates its own server-side resources (assistants, threads,
    files, vector stores) and deletes them in ``finally``. Every handle is
    initialised to ``None`` before the ``try``. If creating one resource
    fails, cleanup then skips the missing handles instead of raising
    ``UnboundLocalError``, which would hide the real error and leak the
    resources that were created.
    """

    def handle_run_failure(self, run: Run):
        """Fail (or skip) the current test unless *run* ended in a usable state.

        Skips the test when a failed run reports a rate limit. Raises
        ``openai.OpenAIError`` for any other failure, and for any status
        other than ``completed`` or ``requires_action``.
        """
        if run.status == "failed":
            # Guard against a failed run with no last_error so the failure is
            # still reported as an OpenAIError rather than an AttributeError.
            message = run.last_error.message if run.last_error is not None else "Run failed without error details"
            if "Rate limit" in message:
                pytest.skip("Skipping - Rate limit reached.")
            raise openai.OpenAIError(message)
        if run.status not in ["completed", "requires_action"]:
            raise openai.OpenAIError(f"Run in unexpected status: {run.status}")

    @configure_async
    @pytest.mark.asyncio
    @pytest.mark.parametrize(
        "api_type, api_version",
        [(ASST_AZURE, PREVIEW)]
    )
    async def test_assistants_crud(self, client_async, api_type, api_version, **kwargs):
        """Create, retrieve, list, update and delete an assistant."""
        assistant = None
        try:
            assistant = await client_async.beta.assistants.create(
                name="python test",
                instructions="You are a personal math tutor. Write and run code to answer math questions.",
                tools=[{"type": "code_interpreter"}],
                model="gpt-4-1106-preview",
            )
            retrieved_assistant = await client_async.beta.assistants.retrieve(
                assistant_id=assistant.id,
            )
            assert retrieved_assistant.id == assistant.id
            assert retrieved_assistant.name == assistant.name
            assert retrieved_assistant.instructions == assistant.instructions
            assert retrieved_assistant.tools == assistant.tools
            assert retrieved_assistant.model == assistant.model
            assert retrieved_assistant.created_at == assistant.created_at
            assert retrieved_assistant.description == assistant.description
            assert retrieved_assistant.metadata == assistant.metadata
            assert retrieved_assistant.object == assistant.object

            list_assistants = client_async.beta.assistants.list()
            async for asst in list_assistants:
                assert asst.id

            modify_assistant = await client_async.beta.assistants.update(
                assistant_id=assistant.id,
                metadata={"key": "value"}
            )
            assert modify_assistant.metadata == {"key": "value"}
        finally:
            if assistant is not None:
                delete_assistant = await client_async.beta.assistants.delete(
                    assistant_id=assistant.id
                )
                assert delete_assistant.id == assistant.id
                assert delete_assistant.deleted is True

    @configure_async
    @pytest.mark.asyncio
    @pytest.mark.parametrize("api_type, api_version", [(ASST_AZURE, PREVIEW), (GPT_4_OPENAI, "v1")])
    async def test_assistants_threads_crud(self, client_async, api_type, api_version, **kwargs):
        """Create, retrieve, update and delete a thread."""
        thread = None
        try:
            thread = await client_async.beta.threads.create(
                messages=[
                    {
                        "role": "user",
                        "content": "I need help with math homework",
                    }
                ],
                metadata={"key": "value"},
            )
            retrieved_thread = await client_async.beta.threads.retrieve(
                thread_id=thread.id,
            )
            assert retrieved_thread.id == thread.id
            assert retrieved_thread.object == thread.object
            assert retrieved_thread.created_at == thread.created_at
            assert retrieved_thread.metadata == thread.metadata

            updated_thread = await client_async.beta.threads.update(
                thread_id=thread.id,
                metadata={"key": "updated"}
            )
            assert updated_thread.metadata == {"key": "updated"}

        finally:
            if thread is not None:
                delete_thread = await client_async.beta.threads.delete(
                    thread_id=thread.id
                )
                assert delete_thread.id == thread.id
                assert delete_thread.deleted is True

    @configure_async
    @pytest.mark.asyncio
    @pytest.mark.parametrize("api_type, api_version", [(ASST_AZURE, PREVIEW), (GPT_4_OPENAI, "v1")])
    async def test_assistants_messages_crud(self, client_async, api_type, api_version, **kwargs):
        """Create, retrieve, list and update a message that carries a file attachment."""
        file_name = f"test{uuid.uuid4()}.txt"
        with open(file_name, "w") as f:
            f.write("test")

        path = pathlib.Path(file_name)
        file = None
        thread = None

        try:
            # Close the upload handle as soon as the upload returns. It would
            # otherwise leak, and on Windows an open handle makes the
            # os.remove() in the finally block fail.
            with open(path, "rb") as upload:
                file = await client_async.files.create(
                    file=upload,
                    purpose="assistants"
                )

            thread = await client_async.beta.threads.create(
                messages=[
                    {
                        "role": "user",
                        "content": "I need help with math homework",
                    }
                ],
                metadata={"key": "value"},
            )

            message = await client_async.beta.threads.messages.create(
                thread_id=thread.id,
                role="user",
                content="what is 2+2?",
                metadata={"math": "addition"},
                attachments=[
                    {
                        "file_id": file.id,
                        "tools": [{"type": "code_interpreter"}]
                    }
                ]
            )
            retrieved_message = await client_async.beta.threads.messages.retrieve(
                thread_id=thread.id,
                message_id=message.id
            )
            assert retrieved_message.id == message.id
            assert retrieved_message.created_at == message.created_at
            assert retrieved_message.metadata == message.metadata
            assert retrieved_message.object == message.object
            assert retrieved_message.thread_id == thread.id
            assert retrieved_message.role == message.role
            assert retrieved_message.content == message.content

            list_messages = client_async.beta.threads.messages.list(
                thread_id=thread.id
            )
            async for msg in list_messages:
                assert msg.id

            modify_message = await client_async.beta.threads.messages.update(
                thread_id=thread.id,
                message_id=message.id,
                metadata={"math": "updated"}
            )
            assert modify_message.metadata == {"math": "updated"}

        finally:
            os.remove(path)
            if thread is not None:
                delete_thread = await client_async.beta.threads.delete(
                    thread_id=thread.id
                )
                assert delete_thread.id == thread.id
                assert delete_thread.deleted is True
            if file is not None:
                delete_file = await client_async.files.delete(file.id)
                assert delete_file.deleted is True

    @configure_async
    @pytest.mark.asyncio
    @pytest.mark.parametrize("api_type, api_version", [(ASST_AZURE, PREVIEW), (GPT_4_OPENAI, "v1")])
    async def test_assistants_runs_code(self, client_async, api_type, api_version, **kwargs):
        """Run a code-interpreter assistant to completion, then update the run's metadata."""
        assistant = None
        thread = None
        try:
            assistant = await client_async.beta.assistants.create(
                name="python test",
                instructions="You are a personal math tutor. Write and run code to answer math questions.",
                tools=[{"type": "code_interpreter"}],
                model="gpt-4-1106-preview",
            )
            thread = await client_async.beta.threads.create()

            await client_async.beta.threads.messages.create(
                thread_id=thread.id,
                role="user",
                content="I need to solve the equation `3x + 11 = 14`. Can you help me?",
            )

            run = await client_async.beta.threads.runs.create_and_poll(
                thread_id=thread.id,
                assistant_id=assistant.id,
                instructions="Please address the user as Jane Doe.",
                additional_instructions="After solving each equation, say 'Isn't math fun?'",
            )
            self.handle_run_failure(run)
            if run.status == "completed":
                messages = client_async.beta.threads.messages.list(thread_id=thread.id)

                async for message in messages:
                    assert message.content[0].type == "text"
                    assert message.content[0].text.value

            run = await client_async.beta.threads.runs.update(
                thread_id=thread.id,
                run_id=run.id,
                metadata={"user": "user123"}
            )
            assert run.metadata == {"user": "user123"}

        finally:
            if assistant is not None:
                delete_assistant = await client_async.beta.assistants.delete(
                    assistant_id=assistant.id
                )
                assert delete_assistant.id == assistant.id
                assert delete_assistant.deleted is True

            if thread is not None:
                delete_thread = await client_async.beta.threads.delete(
                    thread_id=thread.id
                )
                assert delete_thread.id == thread.id
                assert delete_thread.deleted is True

    @configure_async
    @pytest.mark.asyncio
    @pytest.mark.parametrize("api_type, api_version", [(ASST_AZURE, PREVIEW), (GPT_4_OPENAI, "v1")])
    async def test_assistants_runs_file_search(self, client_async, api_type, api_version, **kwargs):
        """Answer a question from an uploaded document using the file_search tool."""
        file_name = f"test{uuid.uuid4()}.txt"
        with open(file_name, "w") as f:
            f.write("Contoso company policy requires that all employees take at least 10 vacation days a year.")

        path = pathlib.Path(file_name)
        vector_store = None
        assistant = None
        thread = None

        try:
            vector_store = await client_async.vector_stores.create(
                name="Support FAQ",
            )
            await client_async.vector_stores.files.upload_and_poll(
                vector_store_id=vector_store.id,
                file=path
            )
            assistant = await client_async.beta.assistants.create(
                name="python test",
                instructions="You help answer questions about Contoso company policy.",
                tools=[{"type": "file_search"}],
                tool_resources={
                    "file_search": {
                        "vector_store_ids": [vector_store.id]
                    }
                },
                model="gpt-4-1106-preview"
            )
            thread = await client_async.beta.threads.create(
                messages=[
                    {"role": "user", "content": "How many vacation days am I required to take as a Contoso employee?"}
                ]
            )

            run = await client_async.beta.threads.runs.create_and_poll(
                assistant_id=assistant.id,
                thread_id=thread.id,
            )
            self.handle_run_failure(run)

            # Ask for the retrieved chunk text to be included in the step details.
            run_steps = client_async.beta.threads.runs.steps.list(
                thread_id=thread.id,
                run_id=run.id,
                include=["step_details.tool_calls[*].file_search.results[*].content"]
            )
            async for step in run_steps:
                assert step
                if step.step_details.type == "tool_calls":
                    assert step.step_details.tool_calls[0].file_search.results[0].content[0].text

            if run.status == "completed":
                messages = client_async.beta.threads.messages.list(thread_id=run.thread_id)

                async for message in messages:
                    assert message.content[0].type == "text"
                    assert message.content[0].text.value

        finally:
            os.remove(path)
            if assistant is not None:
                delete_assistant = await client_async.beta.assistants.delete(
                    assistant_id=assistant.id
                )
                assert delete_assistant.id == assistant.id
                assert delete_assistant.deleted is True

            # Delete via the thread handle: `run` is unbound if run creation failed.
            if thread is not None:
                delete_thread = await client_async.beta.threads.delete(
                    thread_id=thread.id
                )
                assert delete_thread.id == thread.id
                assert delete_thread.deleted is True

            if vector_store is not None:
                deleted_vector_store = await client_async.vector_stores.delete(
                    vector_store_id=vector_store.id
                )
                assert deleted_vector_store.deleted is True

    @configure_async
    @pytest.mark.asyncio
    @pytest.mark.parametrize("api_type, api_version", [(ASST_AZURE, PREVIEW), (GPT_4_OPENAI, "v1")])
    async def test_assistants_runs_functions(self, client_async, api_type, api_version, **kwargs):
        """Drive a function-calling run through submit_tool_outputs, then inspect runs and steps."""
        assistant = None
        run = None
        try:
            assistant = await client_async.beta.assistants.create(
                name="python test",
                instructions="You help answer questions about the weather.",
                tools=[
                    {
                        "type": "function",
                        "function": {
                            "name": "get_current_weather",
                            "description": "Get the current weather",
                            "parameters": {
                                "type": "object",
                                "properties": {
                                    "location": {
                                        "type": "string",
                                        "description": "The city and state, e.g. San Francisco, CA",
                                    },
                                    "format": {
                                        "type": "string",
                                        "enum": ["celsius", "fahrenheit"],
                                        "description": "The temperature unit to use. Infer this from the users location.",
                                    },
                                },
                                "required": ["location"],
                            }
                        }
                    }
                ],
                model="gpt-4-1106-preview",
            )

            # The thread is created implicitly here, so its id is only known via the run.
            run = await client_async.beta.threads.create_and_run_poll(
                assistant_id=assistant.id,
                thread={
                    "messages": [
                        {"role": "user", "content": "How's the weather in Seattle?"}
                    ]
                }
            )
            self.handle_run_failure(run)
            if run.status == "requires_action":
                run = await client_async.beta.threads.runs.submit_tool_outputs_and_poll(
                    thread_id=run.thread_id,
                    run_id=run.id,
                    tool_outputs=[
                        {
                            "tool_call_id": run.required_action.submit_tool_outputs.tool_calls[0].id,
                            "output": "{\"temperature\": \"22\", \"unit\": \"celsius\", \"description\": \"Sunny\"}"
                        }
                    ]
                )
            self.handle_run_failure(run)
            if run.status == "completed":
                messages = client_async.beta.threads.messages.list(thread_id=run.thread_id)

                async for message in messages:
                    assert message.content[0].type == "text"
                    assert message.content[0].text.value

            runs = client_async.beta.threads.runs.list(thread_id=run.thread_id)
            async for r in runs:
                assert r.id == run.id
                assert r.thread_id == run.thread_id
                assert r.assistant_id == run.assistant_id
                assert r.created_at == run.created_at
                assert r.instructions == run.instructions
                assert r.tools == run.tools
                assert r.metadata == run.metadata

                steps = [
                    step
                    async for step in client_async.beta.threads.runs.steps.list(
                        thread_id=run.thread_id,
                        run_id=r.id
                    )
                ]
                # The retrieve check below needs at least one step to look up;
                # fail clearly instead of with a NameError on an unbound loop variable.
                assert steps
                for step in steps:
                    assert step.id

                retrieved_step = await client_async.beta.threads.runs.steps.retrieve(
                    thread_id=run.thread_id,
                    run_id=r.id,
                    step_id=steps[-1].id
                )
                assert retrieved_step.id
                assert retrieved_step.created_at
                assert retrieved_step.run_id
                assert retrieved_step.thread_id
                assert retrieved_step.assistant_id
                assert retrieved_step.type
                assert retrieved_step.step_details

        finally:
            if assistant is not None:
                delete_assistant = await client_async.beta.assistants.delete(
                    assistant_id=assistant.id
                )
                assert delete_assistant.id == assistant.id
                assert delete_assistant.deleted is True

            if run is not None:
                delete_thread = await client_async.beta.threads.delete(
                    thread_id=run.thread_id
                )
                assert delete_thread.id
                assert delete_thread.deleted is True

    @configure_async
    @pytest.mark.asyncio
    @pytest.mark.parametrize("api_type, api_version", [(ASST_AZURE, PREVIEW), (GPT_4_OPENAI, "v1")])
    async def test_assistants_streaming(self, client_async, api_type, api_version, **kwargs):
        """Consume a raw ``stream=True`` run event by event."""
        assistant = await client_async.beta.assistants.create(
            name="Math Tutor",
            instructions="You are a personal math tutor. Write and run code to answer math questions.",
            tools=[{"type": "code_interpreter"}],
            model="gpt-4-1106-preview",
        )
        thread = None
        try:
            thread = await client_async.beta.threads.create()
            await client_async.beta.threads.messages.create(
                thread_id=thread.id,
                role="user",
                content="I need to solve the equation `3x + 11 = 14`. Can you help me?",
            )
            stream = await client_async.beta.threads.runs.create(
                thread_id=thread.id,
                assistant_id=assistant.id,
                instructions="Please address the user as Jane Doe. The user has a premium account.",
                stream=True,
            )

            async for event in stream:
                assert event
        finally:
            await client_async.beta.assistants.delete(assistant.id)
            # The thread was previously never deleted and leaked on every run.
            if thread is not None:
                await client_async.beta.threads.delete(thread_id=thread.id)

    @configure_async
    @pytest.mark.asyncio
    @pytest.mark.parametrize("api_type, api_version", [(ASST_AZURE, PREVIEW), (GPT_4_OPENAI, "v1")])
    async def test_assistants_stream_event_handler(self, client_async, api_type, api_version, **kwargs):
        """Stream a run through ``AsyncEventHandler`` until it finishes."""
        assistant = await client_async.beta.assistants.create(
            name="Math Tutor",
            instructions="You are a personal math tutor. Write and run code to answer math questions.",
            tools=[{"type": "code_interpreter"}],
            model="gpt-4-1106-preview"
        )
        thread = None

        try:
            question = "I need to solve the equation `3x + 11 = 14`. Can you help me and then generate an image with the answer?"

            thread = await client_async.beta.threads.create(
                messages=[
                    {
                        "role": "user",
                        "content": question,
                    },
                ]
            )

            async with client_async.beta.threads.runs.stream(
                thread_id=thread.id,
                assistant_id=assistant.id,
                instructions="Please address the user as Jane Doe. The user has a premium account.",
                event_handler=AsyncEventHandler(),
            ) as stream:
                await stream.until_done()
        finally:
            await client_async.beta.assistants.delete(assistant.id)
            # The thread was previously never deleted and leaked on every run.
            if thread is not None:
                await client_async.beta.threads.delete(thread_id=thread.id)