import datetime
import re

import boto3
import pytest
from botocore.exceptions import ClientError

from moto import mock_aws
from moto.core import DEFAULT_ACCOUNT_ID as ACCOUNT_ID

# Shared values for the SageMaker processing-job tests in this module.
TEST_REGION_NAME = "us-east-1"
FAKE_PROCESSING_JOB_NAME = "MyProcessingJob"
# IAM role ARN in the mocked default account.
FAKE_ROLE_ARN = f"arn:aws:iam::{ACCOUNT_ID}:role/FakeRole"
# ECR image URI located in the same region as the test client.
FAKE_CONTAINER = (
    f"382416733822.dkr.ecr.{TEST_REGION_NAME}.amazonaws.com/linear-learner:1"
)


@pytest.fixture(name="sagemaker_client")
def fixture_sagemaker_client():
    """Yield a SageMaker client whose AWS calls are served by moto's mock.

    The mock context is entered before the client is created and exited when
    the fixture is torn down after the requesting test finishes.
    """
    with mock_aws():
        client = boto3.client("sagemaker", region_name=TEST_REGION_NAME)
        yield client


class MyProcessingJobModel:
    """Build and submit a SageMaker ``CreateProcessingJob`` request for tests.

    Only ``processing_job_name`` and ``role_arn`` are required. Every request
    section (app specification, network config, inputs, output config,
    resources, stopping condition) falls back to a small default when left as
    ``None``. The default S3 URIs are derived from ``bucket`` and ``prefix``,
    and the default app specification runs ``container``.
    """

    def __init__(
        self,
        processing_job_name,
        role_arn,
        container=None,
        bucket=None,
        prefix=None,
        app_specification=None,
        network_config=None,
        processing_inputs=None,
        processing_output_config=None,
        processing_resources=None,
        stopping_condition=None,
    ):
        """Store the request fields, filling in defaults for omitted sections.

        Args:
            processing_job_name: Value sent as ``ProcessingJobName``.
            role_arn: Value sent as ``RoleArn``.
            container: Image URI used only by the default app specification;
                a missing/empty value selects a scikit-learn image.
            bucket: S3 bucket used only by the default input/output URIs
                (falls back to ``"my-bucket"``).
            prefix: S3 key prefix used only by the default input/output URIs
                (falls back to ``"sagemaker"``).
            app_specification, network_config, processing_inputs,
            processing_output_config, processing_resources,
            stopping_condition: Request sections sent verbatim by ``save``.
                ``None`` selects the built-in default; any other value,
                including an empty list or dict, is used exactly as given.
        """
        self.processing_job_name = processing_job_name
        self.role_arn = role_arn
        # An empty string is meaningless for these scalars (it would yield
        # URIs such as "s3:///..."), so any falsy value falls back here.
        self.container = (
            container
            or "683313688378.dkr.ecr.us-east-1.amazonaws.com/sagemaker-scikit-learn:0.23-1-cpu-py3"
        )
        self.bucket = bucket or "my-bucket"
        self.prefix = prefix or "sagemaker"
        processing_uri = f"s3://{self.bucket}/{self.prefix}/processing/"

        # Request sections compare against None rather than truthiness, so a
        # deliberately empty value (e.g. ``processing_inputs=[]``) is sent as
        # given instead of being silently replaced by the default.
        if app_specification is None:
            app_specification = {
                "ImageUri": self.container,
                "ContainerEntrypoint": ["python3"],
            }
        self.app_specification = app_specification

        if network_config is None:
            network_config = {
                "EnableInterContainerTrafficEncryption": False,
                "EnableNetworkIsolation": False,
            }
        self.network_config = network_config

        if processing_inputs is None:
            processing_inputs = [
                {
                    "InputName": "input",
                    "AppManaged": False,
                    "S3Input": {
                        "S3Uri": processing_uri,
                        "LocalPath": "/opt/ml/processing/input",
                        "S3DataType": "S3Prefix",
                        "S3InputMode": "File",
                        "S3DataDistributionType": "FullyReplicated",
                        "S3CompressionType": "None",
                    },
                }
            ]
        self.processing_inputs = processing_inputs

        if processing_output_config is None:
            processing_output_config = {
                "Outputs": [
                    {
                        "OutputName": "output",
                        "S3Output": {
                            "S3Uri": processing_uri,
                            "LocalPath": "/opt/ml/processing/output",
                            "S3UploadMode": "EndOfJob",
                        },
                        "AppManaged": False,
                    }
                ]
            }
        self.processing_output_config = processing_output_config

        if processing_resources is None:
            processing_resources = {
                "ClusterConfig": {
                    "InstanceCount": 1,
                    "InstanceType": "ml.m5.large",
                    "VolumeSizeInGB": 10,
                },
            }
        self.processing_resources = processing_resources

        if stopping_condition is None:
            stopping_condition = {
                "MaxRuntimeInSeconds": 3600,
            }
        self.stopping_condition = stopping_condition

    def save(self, sagemaker_client):
        """Submit this job through ``sagemaker_client.create_processing_job``.

        ``container``, ``bucket`` and ``prefix`` are not sent directly; they
        only shaped the defaults built in ``__init__``.

        Args:
            sagemaker_client: Object exposing ``create_processing_job``
                (a boto3 SageMaker client in these tests).

        Returns:
            Whatever ``create_processing_job`` returns; the tests read
            ``ProcessingJobArn`` from it.
        """
        params = {
            "AppSpecification": self.app_specification,
            "NetworkConfig": self.network_config,
            "ProcessingInputs": self.processing_inputs,
            "ProcessingJobName": self.processing_job_name,
            "ProcessingOutputConfig": self.processing_output_config,
            "ProcessingResources": self.processing_resources,
            "RoleArn": self.role_arn,
            "StoppingCondition": self.stopping_condition,
        }

        return sagemaker_client.create_processing_job(**params)


def test_create_processing_job(sagemaker_client):
    """Creating a job returns its ARN, and describing it echoes the request."""
    arn_pattern = (
        rf"^arn:aws:sagemaker:.*:.*:processing-job/{FAKE_PROCESSING_JOB_NAME}$"
    )
    job = MyProcessingJobModel(
        processing_job_name=FAKE_PROCESSING_JOB_NAME,
        role_arn=FAKE_ROLE_ARN,
        container=FAKE_CONTAINER,
        bucket="my-bucket",
        prefix="my-prefix",
        app_specification={
            "ImageUri": FAKE_CONTAINER,
            "ContainerEntrypoint": ["python3", "app.py"],
        },
        processing_resources={
            "ClusterConfig": {
                "InstanceCount": 2,
                "InstanceType": "ml.m5.xlarge",
                "VolumeSizeInGB": 20,
            },
        },
        stopping_condition={"MaxRuntimeInSeconds": 60 * 60},
    )

    created = job.save(sagemaker_client)
    assert re.match(arn_pattern, created["ProcessingJobArn"])

    described = sagemaker_client.describe_processing_job(
        ProcessingJobName=FAKE_PROCESSING_JOB_NAME
    )
    assert described["ProcessingJobName"] == FAKE_PROCESSING_JOB_NAME
    assert re.match(arn_pattern, described["ProcessingJobArn"])

    entrypoint = described["AppSpecification"]["ContainerEntrypoint"]
    assert "python3" in entrypoint
    assert "app.py" in entrypoint

    assert described["RoleArn"] == FAKE_ROLE_ARN
    assert described["ProcessingJobStatus"] == "Completed"
    for timestamp_field in ("CreationTime", "LastModifiedTime"):
        assert isinstance(described[timestamp_field], datetime.datetime)


def test_list_processing_jobs(sagemaker_client):
    """A single saved job is listed alone, with its name and ARN, on one page."""
    MyProcessingJobModel(
        processing_job_name=FAKE_PROCESSING_JOB_NAME, role_arn=FAKE_ROLE_ARN
    ).save(sagemaker_client)

    listing = sagemaker_client.list_processing_jobs()
    summaries = listing["ProcessingJobSummaries"]
    assert len(summaries) == 1

    summary = summaries[0]
    assert summary["ProcessingJobName"] == FAKE_PROCESSING_JOB_NAME
    assert re.match(
        rf"^arn:aws:sagemaker:.*:.*:processing-job/{FAKE_PROCESSING_JOB_NAME}$",
        summary["ProcessingJobArn"],
    )
    assert listing.get("NextToken") is None


def test_list_processing_jobs_multiple(sagemaker_client):
    """MaxResults caps the page size; an unbounded listing returns every job."""
    jobs_to_create = (
        ("blah", "arn:aws:sagemaker:us-east-1:000000000000:x-x/foobar"),
        ("blah2", "arn:aws:sagemaker:us-east-1:000000000000:x-x/foobar2"),
    )
    for job_name, role in jobs_to_create:
        MyProcessingJobModel(processing_job_name=job_name, role_arn=role).save(
            sagemaker_client
        )

    single_page = sagemaker_client.list_processing_jobs(MaxResults=1)
    assert len(single_page["ProcessingJobSummaries"]) == 1

    full_listing = sagemaker_client.list_processing_jobs()
    assert len(full_listing["ProcessingJobSummaries"]) == 2
    assert full_listing.get("NextToken") is None


def test_list_processing_jobs_none(sagemaker_client):
    """Listing before any job is created yields an empty summary list."""
    summaries = sagemaker_client.list_processing_jobs()["ProcessingJobSummaries"]
    assert not summaries


def test_list_processing_jobs_should_validate_input(sagemaker_client):
    """An unknown StatusEquals value and a malformed NextToken are rejected."""
    bad_status = "blah"
    with pytest.raises(ClientError) as status_failure:
        sagemaker_client.list_processing_jobs(StatusEquals=bad_status)
    status_error = status_failure.value.response["Error"]
    assert status_error["Code"] == "ValidationException"
    assert status_error["Message"] == (
        f"1 validation errors detected: Value '{bad_status}' at "
        "'statusEquals' failed to satisfy constraint: Member must satisfy "
        "enum value set: ['Completed', 'Stopped', 'InProgress', 'Stopping', "
        "'Failed']"
    )

    bad_token = "asdf"
    with pytest.raises(ClientError) as token_failure:
        sagemaker_client.list_processing_jobs(NextToken=bad_token)
    token_error = token_failure.value.response["Error"]
    assert token_error["Code"] == "ValidationException"
    # The expected message contains a literal, unformatted "{0}"; this pins the
    # exact text returned rather than the rejected token value.
    assert token_error["Message"] == 'Invalid pagination token because "{0}".'


def test_list_processing_jobs_with_name_filters(sagemaker_client):
    """NameContains keeps only summaries whose job name contains the substring."""
    role_prefix = "arn:aws:sagemaker:us-east-1:000000000000:x-x/"
    # Five "xgboost-N" jobs followed by five "vgg-N" jobs.
    for family, role_stem in (("xgboost", "foobar"), ("vgg", "barfoo")):
        for idx in range(5):
            MyProcessingJobModel(
                processing_job_name=f"{family}-{idx}",
                role_arn=f"{role_prefix}{role_stem}-{idx}",
            ).save(sagemaker_client)

    by_family = sagemaker_client.list_processing_jobs(NameContains="xgboost")
    assert len(by_family["ProcessingJobSummaries"]) == 5

    # "2" matches exactly one job per family: xgboost-2 and vgg-2.
    by_digit = sagemaker_client.list_processing_jobs(NameContains="2")
    assert len(by_digit["ProcessingJobSummaries"]) == 2


def test_list_processing_jobs_paginated(sagemaker_client):
    """Following NextToken with MaxResults=1 yields jobs one at a time, in order."""
    role_prefix = "arn:aws:sagemaker:us-east-1:000000000000:x-x/foobar-"
    for idx in range(5):
        MyProcessingJobModel(
            processing_job_name=f"xgboost-{idx}", role_arn=f"{role_prefix}{idx}"
        ).save(sagemaker_client)

    page_args = {"NameContains": "xgboost", "MaxResults": 1}
    for expected_name in ("xgboost-0", "xgboost-1"):
        page = sagemaker_client.list_processing_jobs(**page_args)
        summaries = page["ProcessingJobSummaries"]
        assert len(summaries) == 1
        assert summaries[0]["ProcessingJobName"] == expected_name
        assert page.get("NextToken") is not None
        # Continue from where this page stopped.
        page_args["NextToken"] = page.get("NextToken")


def test_list_processing_jobs_paginated_with_target_in_middle(sagemaker_client):
    """Matching jobs created after non-matching ones need larger pages to appear."""
    role_prefix = "arn:aws:sagemaker:us-east-1:000000000000:x-x/"
    for family, role_stem in (("xgboost", "foobar"), ("vgg", "barfoo")):
        for idx in range(5):
            MyProcessingJobModel(
                processing_job_name=f"{family}-{idx}",
                role_arn=f"{role_prefix}{role_stem}-{idx}",
            ).save(sagemaker_client)

    def list_vgg(limit):
        return sagemaker_client.list_processing_jobs(
            NameContains="vgg", MaxResults=limit
        )

    # The expectations below imply MaxResults bounds how many jobs are scanned
    # (the five xgboost jobs come first), not how many matches are returned.
    one = list_vgg(1)
    assert len(one["ProcessingJobSummaries"]) == 0
    assert one.get("NextToken") is not None

    six = list_vgg(6)
    assert len(six["ProcessingJobSummaries"]) == 1
    assert six["ProcessingJobSummaries"][0]["ProcessingJobName"] == "vgg-0"
    assert six.get("NextToken") is not None

    ten = list_vgg(10)
    assert len(ten["ProcessingJobSummaries"]) == 5
    assert ten["ProcessingJobSummaries"][-1]["ProcessingJobName"] == "vgg-4"
    assert ten.get("NextToken") is None


def test_list_processing_jobs_paginated_with_fragmented_targets(sagemaker_client):
    """Follow NextToken when matches for a filter are scattered among other jobs."""
    role_prefix = "arn:aws:sagemaker:us-east-1:000000000000:x-x/"
    for family, role_stem in (("xgboost", "foobar"), ("vgg", "barfoo")):
        for idx in range(5):
            MyProcessingJobModel(
                processing_job_name=f"{family}-{idx}",
                role_arn=f"{role_prefix}{role_stem}-{idx}",
            ).save(sagemaker_client)

    # Each step: (MaxResults, expected match count, whether a NextToken follows).
    expected_pages = ((8, 2, True), (1, 0, True), (1, 0, False))
    next_token = None
    for max_results, match_count, has_more in expected_pages:
        request = {"NameContains": "2", "MaxResults": max_results}
        if next_token is not None:
            request["NextToken"] = next_token
        page = sagemaker_client.list_processing_jobs(**request)
        assert len(page["ProcessingJobSummaries"]) == match_count
        next_token = page.get("NextToken")
        assert (next_token is not None) == has_more


def test_add_and_delete_tags_in_training_job(sagemaker_client):
    """Tags can be added to, listed on, and removed from a processing job.

    NOTE: despite "training_job" in the name, this test tags a *processing*
    job. The name is kept unchanged so existing test selections still match.
    """
    # Reuse the module-level constants rather than re-declaring identical
    # literals that could silently drift apart from the other tests.
    job = MyProcessingJobModel(
        FAKE_PROCESSING_JOB_NAME,
        FAKE_ROLE_ARN,
        container=FAKE_CONTAINER,
        bucket="my-bucket",
        prefix="my-prefix",
        app_specification={
            "ImageUri": FAKE_CONTAINER,
            "ContainerEntrypoint": ["python3", "app.py"],
        },
        processing_resources={
            "ClusterConfig": {
                "InstanceCount": 2,
                "InstanceType": "ml.m5.xlarge",
                "VolumeSizeInGB": 20,
            },
        },
        stopping_condition={"MaxRuntimeInSeconds": 60 * 60},
    )
    resource_arn = job.save(sagemaker_client)["ProcessingJobArn"]

    tags = [
        {"Key": "myKey", "Value": "myValue"},
    ]
    response = sagemaker_client.add_tags(ResourceArn=resource_arn, Tags=tags)
    assert response["ResponseMetadata"]["HTTPStatusCode"] == 200

    response = sagemaker_client.list_tags(ResourceArn=resource_arn)
    assert response["Tags"] == tags

    tag_keys = [tag["Key"] for tag in tags]
    response = sagemaker_client.delete_tags(ResourceArn=resource_arn, TagKeys=tag_keys)
    assert response["ResponseMetadata"]["HTTPStatusCode"] == 200

    response = sagemaker_client.list_tags(ResourceArn=resource_arn)
    assert response["Tags"] == []
