File: test_dmac_compose_model_async.py

package info (click to toggle)
python-azure 20250603+git-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 851,724 kB
  • sloc: python: 7,362,925; ansic: 804; javascript: 287; makefile: 195; sh: 145; xml: 109
file content (111 lines) | stat: -rw-r--r-- 4,925 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
# coding=utf-8
# ------------------------------------
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# ------------------------------------

import uuid
import functools
from devtools_testutils.aio import recorded_by_proxy_async
from devtools_testutils import set_bodiless_matcher
from azure.ai.documentintelligence.aio import DocumentIntelligenceAdministrationClient
from azure.ai.documentintelligence.models import (
    AzureBlobContentSource,
    BuildDocumentModelRequest,
    BuildDocumentClassifierRequest,
    ComposeDocumentModelRequest,
    ClassifierDocumentTypeDetails,
    DocumentTypeDetails,
)
from asynctestcase import AsyncDocumentIntelligenceTest
from conftest import skip_flaky_test
from preparers import DocumentIntelligencePreparer, GlobalClientPreparerAsync as _GlobalClientPreparer

DocumentModelAdministrationClientPreparer = functools.partial(
    _GlobalClientPreparer, DocumentIntelligenceAdministrationClient
)


class TestTrainingAsync(AsyncDocumentIntelligenceTest):
    """Async recorded tests for composing document models via the administration client."""

    @skip_flaky_test
    @DocumentIntelligencePreparer()
    @DocumentModelAdministrationClientPreparer()
    @recorded_by_proxy_async
    async def test_compose_model(self, client, documentintelligence_training_data_classifier_sas_url, **kwargs):
        """Build two template models and a classifier, then compose them into one model.

        Every resource ID is taken from ``recorded_variables`` (seeded with a fresh
        UUID only when absent), and the dict is returned at the end -- presumably so
        the recording decorator can persist the IDs and reuse them on playback.

        :param client: the async ``DocumentIntelligenceAdministrationClient`` under test.
        :param documentintelligence_training_data_classifier_sas_url: container URL
            used as the training data source for both models and the classifier.
        :return: the recorded-variables dict containing all generated IDs.
        """
        # Match recorded requests without comparing request bodies.
        set_bodiless_matcher()
        recorded_variables = kwargs.pop("variables", {})
        recorded_variables.setdefault("model_id1", str(uuid.uuid4()))
        recorded_variables.setdefault("model_id2", str(uuid.uuid4()))
        # The classifier needs its own ID up front: it cannot be read from the
        # classifier object before that object has been built.
        recorded_variables.setdefault("classifier_id", str(uuid.uuid4()))
        recorded_variables.setdefault("composed_id", str(uuid.uuid4()))

        async with client:
            # First component model.
            request = BuildDocumentModelRequest(
                model_id=recorded_variables.get("model_id1"),
                description="model1",
                build_mode="template",
                azure_blob_source=AzureBlobContentSource(
                    container_url=documentintelligence_training_data_classifier_sas_url
                ),
            )
            poller = await client.begin_build_document_model(request)
            model_1 = await poller.result()

            # Second component model, trained on the same container.
            request = BuildDocumentModelRequest(
                model_id=recorded_variables.get("model_id2"),
                description="model2",
                build_mode="template",
                azure_blob_source=AzureBlobContentSource(
                    container_url=documentintelligence_training_data_classifier_sas_url
                ),
            )
            poller = await client.begin_build_document_model(request)
            model_2 = await poller.result()

            # Classifier with two doc types, each trained from its own blob prefix.
            request = BuildDocumentClassifierRequest(
                classifier_id=recorded_variables.get("classifier_id"),
                description="IRS document classifier",
                doc_types={
                    "IRS-1040-A": ClassifierDocumentTypeDetails(
                        azure_blob_source=AzureBlobContentSource(
                            container_url=documentintelligence_training_data_classifier_sas_url,
                            prefix="IRS-1040-A/train",
                        )
                    ),
                    "IRS-1040-B": ClassifierDocumentTypeDetails(
                        azure_blob_source=AzureBlobContentSource(
                            container_url=documentintelligence_training_data_classifier_sas_url,
                            prefix="IRS-1040-B/train",
                        )
                    ),
                },
            )
            poller = await client.begin_build_classifier(request)
            classifier = await poller.result()
            classifier_id = classifier.classifier_id

            # Each component model gets a distinct key; a repeated dict key would
            # silently overwrite the earlier entry and drop model_1 from the request.
            request = ComposeDocumentModelRequest(
                model_id=recorded_variables.get("composed_id"),
                classifier_id=classifier_id,
                description="my composed model",
                tags={"testkey": "testvalue"},
                doc_types={
                    "formA": DocumentTypeDetails(model_id=model_1.model_id),
                    "formB": DocumentTypeDetails(model_id=model_2.model_id),
                },
            )
            poller = await client.begin_compose_model(request)
            composed_model = await poller.result()

            assert composed_model.api_version
            assert composed_model.model_id == recorded_variables.get("composed_id")
            assert composed_model.description == "my composed model"
            assert composed_model.created_date_time
            assert composed_model.expiration_date_time
            assert composed_model.tags == {"testkey": "testvalue"}
            # Every doc type must be named, and every field in its schema must
            # carry a type and have a confidence entry.
            for name, doc_details in composed_model.doc_types.items():
                assert name
                for key, field in doc_details.field_schema.items():
                    assert key
                    assert field["type"]
                    assert doc_details.field_confidence[key] is not None

        return recorded_variables