import itertools
import sys

# Require Python 3.7+ for ordered dictionaries so that the order of the
# generated tests remains the same (tests are emitted in dict order).
# Usage:
# python3.7 mongos-pin-auto-tests.py > mongos-pin-auto.yml
if sys.version_info[:2] < (3, 7):
    # Report on stderr: stdout is normally redirected into the .yml file, so
    # an error printed there would end up inside the generated tests instead
    # of in front of the user. sys.stderr.write (not print(file=...)) keeps
    # this file compilable on Python 2 so the guard can actually run there.
    sys.stderr.write('ERROR: This script requires Python >= 3.7, not:\n')
    sys.stderr.write(sys.version + '\n')
    sys.stderr.write(
        'Usage: python3.7 mongos-pin-auto-tests.py > mongos-pin-auto.yml\n')
    # sys.exit is always available; the bare exit() helper is injected by the
    # site module and is absent when Python runs with -S.
    sys.exit(1)

# Static prologue of the generated YAML file, printed verbatim before the
# generated tests: the runOn/database/data preamble plus two hand-written
# tests. These tests define the YAML anchors that every TEMPLATE instance
# aliases (&startTransaction, &initialCommand, &assertSessionPinned,
# &assertSessionUnpinned, &abortTransaction, &outcome). Some anchors are
# defined twice (&startTransaction, &initialCommand, &outcome). A YAML alias
# refers to the most recent preceding anchor of that name, so the generated
# tests pick up the second test's definitions, e.g. `outcome: data: *data`.
HEADER = '''# Autogenerated tests that transient errors in a transaction unpin the session.
# See mongos-pin-auto-tests.py
runOn:
    -
        minServerVersion: "4.1.8"
        topology: ["sharded"]
        # serverless proxy doesn't append error labels to errors in transactions
        # caused by failpoints (CLOUDP-88216)
        serverless: "forbid"

database_name: &database_name "transaction-tests"
collection_name: &collection_name "test"

data: &data
  - {_id: 1}
  - {_id: 2}

tests:
  - description: remain pinned after non-transient Interrupted error on insertOne
    useMultipleMongoses: true
    operations:
      - &startTransaction
        name: startTransaction
        object: session0
      - &initialCommand
        name: insertOne
        object: collection
        arguments:
          session: session0
          document: {_id: 3}
        result:
          insertedId: 3
      - name: targetedFailPoint
        object: testRunner
        arguments:
          session: session0
          failPoint:
            configureFailPoint: failCommand
            mode: {times: 1}
            data:
              failCommands: ["insert"]
              errorCode: 11601
      - name: insertOne
        object: collection
        arguments:
          session: session0
          document:
            _id: 4
        result:
          errorLabelsOmit: ["TransientTransactionError", "UnknownTransactionCommitResult"]
          errorCodeName: Interrupted
      - &assertSessionPinned
        name: assertSessionPinned
        object: testRunner
        arguments:
          session: session0
      - &commitTransaction
        name: commitTransaction
        object: session0

    expectations:
      - command_started_event:
          command:
            insert: *collection_name
            documents:
              - _id: 3
            ordered: true
            readConcern:
            lsid: session0
            txnNumber:
              $numberLong: "1"
            startTransaction: true
            autocommit: false
            writeConcern:
          command_name: insert
          database_name: *database_name
      - command_started_event:
          command:
            insert: *collection_name
            documents:
              - _id: 4
            ordered: true
            readConcern:
            lsid: session0
            txnNumber:
              $numberLong: "1"
            startTransaction:
            autocommit: false
            writeConcern:
          command_name: insert
          database_name: *database_name
      - command_started_event:
          command:
            commitTransaction: 1
            lsid: session0
            txnNumber:
              $numberLong: "1"
            startTransaction:
            autocommit: false
            writeConcern:
            recoveryToken: 42
          command_name: commitTransaction
          database_name: admin

    outcome: &outcome
      collection:
        data:
          - {_id: 1}
          - {_id: 2}
          - {_id: 3}

  - description: unpin after transient error within a transaction
    useMultipleMongoses: true
    operations:
      - &startTransaction
        name: startTransaction
        object: session0
      - &initialCommand
        name: insertOne
        object: collection
        arguments:
          session: session0
          document:
            _id: 3
        result:
          insertedId: 3
      - name: targetedFailPoint
        object: testRunner
        arguments:
          session: session0
          failPoint:
            configureFailPoint: failCommand
            mode: { times: 1 }
            data:
              failCommands: ["insert"]
              closeConnection: true
      - name: insertOne
        object: collection
        arguments:
          session: session0
          document:
            _id: 4
        result:
          errorLabelsContain: ["TransientTransactionError"]
          errorLabelsOmit: ["UnknownTransactionCommitResult"]
      # Session unpins from the first mongos after the insert error and
      # abortTransaction succeeds immediately on any mongos.
      - &assertSessionUnpinned
        name: assertSessionUnpinned
        object: testRunner
        arguments:
          session: session0
      - &abortTransaction
        name: abortTransaction
        object: session0

    expectations:
      - command_started_event:
          command:
            insert: *collection_name
            documents:
              - _id: 3
            ordered: true
            readConcern:
            lsid: session0
            txnNumber:
              $numberLong: "1"
            startTransaction: true
            autocommit: false
            writeConcern:
          command_name: insert
          database_name: *database_name
      - command_started_event:
          command:
            insert: *collection_name
            documents:
              - _id: 4
            ordered: true
            readConcern:
            lsid: session0
            txnNumber:
              $numberLong: "1"
            startTransaction:
            autocommit: false
            writeConcern:
          command_name: insert
          database_name: *database_name
      - command_started_event:
          command:
            abortTransaction: 1
            lsid: session0
            txnNumber:
              $numberLong: "1"
            startTransaction:
            autocommit: false
            writeConcern:
            recoveryToken: 42
          command_name: abortTransaction
          database_name: admin

    outcome: &outcome
      collection:
        data: *data

  # The rest of the tests in this file test every operation type against
  # multiple types of transient errors (connection and error code).'''

# YAML for one generated test, filled in with str.format() via
# create_pin_test / create_unpin_test. Placeholders: test_name, error_name,
# op_name, command_name, error_data, object_name, op_args, error_labels,
# assertion.
# - Literal braces are doubled ({{times: 1}}) so format() emits them as-is.
# - Multi-line substitutions ({object_name}, {op_args}) are spliced in
#   verbatim. format() does not re-indent their continuation lines, so those
#   lines must already be indented to match the line they land on.
# - The *startTransaction, *initialCommand, *abortTransaction, *outcome and
#   *{assertion} aliases refer to anchors defined in HEADER.
TEMPLATE = '''
  - description: {test_name} {error_name} error on {op_name} {command_name}
    useMultipleMongoses: true
    operations:
      - *startTransaction
      - *initialCommand
      - name: targetedFailPoint
        object: testRunner
        arguments:
          session: session0
          failPoint:
            configureFailPoint: failCommand
            mode: {{times: 1}}
            data:
              failCommands: ["{command_name}"]
              {error_data}
      - name: {op_name}
        object: {object_name}
        arguments:
          session: session0
          {op_args}
        result:
          {error_labels}: ["TransientTransactionError"]
      - *{assertion}
      - *abortTransaction
    outcome: *outcome
'''


# Maps from op key to (command_name, object_name, op_args):
#   - op key: the driver operation under test. The three 'bulkWrite <command>'
#     keys are all emitted as the bulkWrite operation; the suffix only keeps
#     the dict keys distinct.
#   - command_name: the server command the failCommand failpoint targets.
#   - object_name: the test-runner object the operation is invoked on.
#   - op_args: YAML arguments spliced into TEMPLATE under `arguments:`.
# Continuation lines inside these strings are pre-indented to the column of
# the TEMPLATE line they are substituted into. Keep that indentation intact.
OPS = {
    # Write ops:
    'insertOne': ('insert', 'collection', r'document: {_id: 4}'),
    'insertMany': ('insert', 'collection', r'documents: [{_id: 4}, {_id: 5}]'),
    'updateOne': ('update', 'collection', r'''filter: {_id: 1}
          update: {$inc: {x: 1}}'''),
    'replaceOne': ('update', 'collection', r'''filter: {_id: 1}
          replacement: {y: 1}'''),
    'updateMany': ('update', 'collection', r'''filter: {_id: {$gte: 1}}
          update: {$set: {z: 1}}'''),
    'deleteOne': ('delete', 'collection', r'filter: {_id: 1}'),
    'deleteMany': ('delete', 'collection', r'filter: {_id: {$gte: 1}}'),
    'findOneAndDelete': ('findAndModify', 'collection', r'filter: {_id: 1}'),
    'findOneAndUpdate': ('findAndModify', 'collection', r'''filter: {_id: 1}
          update: {$inc: {x: 1}}
          returnDocument: Before'''),
    'findOneAndReplace': ('findAndModify', 'collection', r'''filter: {_id: 1}
          replacement: {y: 1}
          returnDocument: Before'''),
    # Bulk write insert/update/delete:
    'bulkWrite insert': ('insert', 'collection', r'''requests:
            - name: insertOne
              arguments:
                document: {_id: 1}'''),
    'bulkWrite update': ('update', 'collection', r'''requests:
            - name: updateOne
              arguments:
                filter: {_id: 1}
                update: {$set: {x: 1}}'''),
    'bulkWrite delete': ('delete', 'collection', r'''requests:
            - name: deleteOne
              arguments:
                filter: {_id: 1}'''),
    # Read ops:
    'find': ('find', 'collection', r'filter: {_id: 1}'),
    # countDocuments is targeted via the aggregate command.
    'countDocuments': ('aggregate', 'collection', r'filter: {}'),
    'aggregate': ('aggregate', 'collection', r'pipeline: []'),
    'distinct': ('distinct', 'collection', r'fieldName: _id'),
    # runCommand:
    # object_name carries a second YAML line ('command_name: insert') that
    # lands directly under `object:` in TEMPLATE, at the same indentation.
    'runCommand': (
        'insert',
        r'''database
        command_name: insert''',  # runCommand requires command_name.
        r'''command:
            insert: *collection_name
            documents:
              - _id : 1'''),
}

# Errors after which the session must stay pinned. Each entry maps an
# error_name to the failCommand `data` field(s) that inject that error.
NON_TRANSIENT_ERRORS = dict(
    Interrupted='errorCode: 11601',
)

# Errors after which the session must be unpinned. Same shape as above:
# error_name -> failCommand `data` field(s), in generation order.
TRANSIENT_ERRORS = dict(
    connection='closeConnection: true',
    ShutdownInProgress='errorCode: 91',
)


def _render_test(op_name, error_name, errors, test_name, assertion,
                 error_labels):
    """Render TEMPLATE for one (operation, error) combination.

    :param op_name: key into OPS selecting the operation under test.
    :param error_name: key into ``errors`` selecting the injected failure.
    :param errors: mapping of error_name -> failCommand data
        (NON_TRANSIENT_ERRORS or TRANSIENT_ERRORS).
    :param test_name: description prefix of the generated test.
    :param assertion: name of the YAML anchor aliased as the post-error
        assertion ('assertSessionPinned' or 'assertSessionUnpinned').
    :param error_labels: result key checking the TransientTransactionError
        label ('errorLabelsOmit' or 'errorLabelsContain').
    :returns: YAML text of one test, as produced by TEMPLATE.format().
    :raises KeyError: if op_name is not in OPS or error_name not in errors.
    """
    command_name, object_name, op_args = OPS[op_name]
    error_data = errors[error_name]
    # The 'bulkWrite insert/update/delete' OPS keys all invoke the bulkWrite
    # operation; command_name still distinguishes the generated tests.
    if op_name.startswith('bulkWrite'):
        op_name = 'bulkWrite'
    # Pass every placeholder explicitly rather than **locals(), so adding or
    # renaming a local variable cannot silently change the template fill.
    return TEMPLATE.format(
        test_name=test_name,
        error_name=error_name,
        op_name=op_name,
        command_name=command_name,
        object_name=object_name,
        op_args=op_args,
        error_data=error_data,
        error_labels=error_labels,
        assertion=assertion,
    )


def create_pin_test(op_name, error_name):
    """Return a test that a non-transient error leaves the session pinned.

    The generated test expects the TransientTransactionError label to be
    omitted and aliases assertSessionPinned after the failed operation.
    """
    return _render_test(
        op_name, error_name, NON_TRANSIENT_ERRORS,
        test_name='remain pinned after non-transient',
        assertion='assertSessionPinned',
        error_labels='errorLabelsOmit')


def create_unpin_test(op_name, error_name):
    """Return a test that a transient error unpins the session.

    The generated test expects the TransientTransactionError label to be
    present and aliases assertSessionUnpinned after the failed operation.
    """
    return _render_test(
        op_name, error_name, TRANSIENT_ERRORS,
        test_name='unpin after transient',
        assertion='assertSessionUnpinned',
        error_labels='errorLabelsContain')

def main():
    """Print the complete generated YAML file to stdout.

    Emits HEADER, then one pin test per (OPS key, NON_TRANSIENT_ERRORS key)
    pair, then one unpin test per (OPS key, TRANSIENT_ERRORS key) pair.
    itertools.product iterates OPS in the outer loop, so the output follows
    dict insertion order.
    """
    tests = [create_pin_test(op_name, error_name)
             for op_name, error_name
             in itertools.product(OPS, NON_TRANSIENT_ERRORS)]
    tests.extend(create_unpin_test(op_name, error_name)
                 for op_name, error_name
                 in itertools.product(OPS, TRANSIENT_ERRORS))

    print(HEADER)
    print(''.join(tests))


if __name__ == '__main__':
    main()
