1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111
|
# coding=utf-8
# ------------------------------------
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# ------------------------------------
import uuid
import functools
from devtools_testutils.aio import recorded_by_proxy_async
from devtools_testutils import set_bodiless_matcher
from azure.ai.documentintelligence.aio import DocumentIntelligenceAdministrationClient
from azure.ai.documentintelligence.models import (
AzureBlobContentSource,
BuildDocumentModelRequest,
BuildDocumentClassifierRequest,
ComposeDocumentModelRequest,
ClassifierDocumentTypeDetails,
DocumentTypeDetails,
)
from asynctestcase import AsyncDocumentIntelligenceTest
from conftest import skip_flaky_test
from preparers import DocumentIntelligencePreparer, GlobalClientPreparerAsync as _GlobalClientPreparer
# Bind the async global client preparer to the administration client class, so
# tests decorated with it get that client type (received as ``client`` below).
DocumentModelAdministrationClientPreparer = functools.partial(_GlobalClientPreparer, DocumentIntelligenceAdministrationClient)
class TestTrainingAsync(AsyncDocumentIntelligenceTest):
    """Async recorded tests for composing Document Intelligence models."""

    @skip_flaky_test
    @DocumentIntelligencePreparer()
    @DocumentModelAdministrationClientPreparer()
    @recorded_by_proxy_async
    async def test_compose_model(self, client, documentintelligence_training_data_classifier_sas_url, **kwargs):
        """Build two template models and a classifier, compose them, and check the result.

        :param client: Administration client supplied by the preparer decorators.
        :param documentintelligence_training_data_classifier_sas_url: SAS URL of the
            training-data container used for both models and the classifier.
        :returns: The recorded variables holding the generated resource IDs.
        """
        set_bodiless_matcher()
        # Generated IDs are kept in the recorded variables and returned at the end.
        # Presumably the recording decorator persists them so playback reuses the same IDs.
        recorded_variables = kwargs.pop("variables", {})
        recorded_variables.setdefault("model_id1", str(uuid.uuid4()))
        recorded_variables.setdefault("model_id2", str(uuid.uuid4()))
        # The classifier's ID is part of its build request, so it must be chosen
        # up front rather than read from a classifier that does not exist yet.
        recorded_variables.setdefault("classifier_id", str(uuid.uuid4()))
        recorded_variables.setdefault("composed_id", str(uuid.uuid4()))
        async with client:
            request = BuildDocumentModelRequest(
                model_id=recorded_variables.get("model_id1"),
                description="model1",
                build_mode="template",
                azure_blob_source=AzureBlobContentSource(
                    container_url=documentintelligence_training_data_classifier_sas_url
                ),
            )
            poller = await client.begin_build_document_model(request)
            model_1 = await poller.result()

            request = BuildDocumentModelRequest(
                model_id=recorded_variables.get("model_id2"),
                description="model2",
                build_mode="template",
                azure_blob_source=AzureBlobContentSource(
                    container_url=documentintelligence_training_data_classifier_sas_url
                ),
            )
            poller = await client.begin_build_document_model(request)
            model_2 = await poller.result()

            request = BuildDocumentClassifierRequest(
                classifier_id=recorded_variables.get("classifier_id"),
                description="IRS document classifier",
                doc_types={
                    "IRS-1040-A": ClassifierDocumentTypeDetails(
                        azure_blob_source=AzureBlobContentSource(
                            container_url=documentintelligence_training_data_classifier_sas_url,
                            prefix="IRS-1040-A/train",
                        )
                    ),
                    "IRS-1040-B": ClassifierDocumentTypeDetails(
                        azure_blob_source=AzureBlobContentSource(
                            container_url=documentintelligence_training_data_classifier_sas_url,
                            prefix="IRS-1040-B/train",
                        )
                    ),
                },
            )
            poller = await client.begin_build_classifier(request)
            classifier = await poller.result()
            classifier_id = classifier.classifier_id

            # Each doc type key must be distinct: a repeated key in a dict literal
            # silently replaces the earlier entry, dropping one of the models.
            # NOTE(review): confirm whether the service expects these keys to match
            # the classifier's doc types ("IRS-1040-A" / "IRS-1040-B").
            request = ComposeDocumentModelRequest(
                model_id=recorded_variables.get("composed_id"),
                classifier_id=classifier_id,
                description="my composed model",
                tags={"testkey": "testvalue"},
                doc_types={
                    "formA": DocumentTypeDetails(model_id=model_1.model_id),
                    "formB": DocumentTypeDetails(model_id=model_2.model_id),
                },
            )
            poller = await client.begin_compose_model(request)
            composed_model = await poller.result()

            assert composed_model.api_version
            assert composed_model.model_id == recorded_variables.get("composed_id")
            assert composed_model.description == "my composed model"
            assert composed_model.created_date_time
            assert composed_model.expiration_date_time
            assert composed_model.tags == {"testkey": "testvalue"}
            # Every field in every doc type must carry a type and a confidence entry.
            for name, doc_details in composed_model.doc_types.items():
                assert name
                for key, field in doc_details.field_schema.items():
                    assert key
                    assert field["type"]
                    assert doc_details.field_confidence[key] is not None
        return recorded_variables
|