#!/usr/bin/env python
#
# Copyright 2016 Confluent Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import argparse
from confluent_kafka import Consumer, KafkaError, KafkaException
from verifiable_client import VerifiableClient

class VerifiableConsumer(VerifiableClient):
    """
    confluent-kafka-python backed VerifiableConsumer class for use with
    Kafka's kafkatest client tests.
    """
    def __init__ (self, conf):
        """
        \p conf is a config dict passed to confluent_kafka.Consumer()
        """
        super(VerifiableConsumer, self).__init__(conf)
        self.conf['on_commit'] = self.on_commit
        self.consumer = Consumer(**conf)
        self.consumed_msgs = 0
        self.consumed_msgs_last_reported = 0
        self.consumed_msgs_at_last_commit = 0
        self.use_auto_commit = False
        self.use_async_commit = False
        self.max_msgs = -1
        self.verbose = False
        self.assignment = []
        self.assignment_dict = dict()


    def find_assignment (self, topic, partition):
        """ Find and return existing assignment based on \p topic and \p partition,
        or None on miss. """
        skey = '%s %d' % (topic, partition)
        return self.assignment_dict.get(skey)
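
    # For illustration (values hypothetical): find_assignment('test', 0)
    # returns the AssignedPartition stored under skey 'test 0', or None
    # if that partition is not currently assigned.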


    def send_records_consumed (self, immediate=False):
        """ Send records_consumed, every 100 messages, on timeout,
            or if immediate is set. """
        if (self.consumed_msgs <= self.consumed_msgs_last_reported +
            (0 if immediate else 100)):
            return

        if len(self.assignment) == 0:
            return

        d = {'name': 'records_consumed',
             'count': self.consumed_msgs - self.consumed_msgs_last_reported,
             'partitions': []}

        for a in self.assignment:
            if a.min_offset == -1:
                # Skip partitions that haven't had any messages since last time.
                # This is to circumvent some minOffset checks in kafkatest.
                continue
            d['partitions'].append(a.to_dict())
            a.min_offset = -1

        self.send(d)
        self.consumed_msgs_last_reported = self.consumed_msgs
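
    # For illustration, the event dict built above is serialized by
    # self.send() into JSON along these lines (topic, count and offsets
    # are hypothetical):
    #   {"name": "records_consumed", "count": 100,
    #    "partitions": [{"topic": "test", "partition": 0,
    #                    "minOffset": 12, "maxOffset": 111}]}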


    def send_assignment (self, evtype, partitions):
        """ Send assignment update, \p evtype is either 'assigned' or 'revoked' """
        d = { 'name': 'partitions_' + evtype,
              'partitions': [{'topic': x.topic, 'partition': x.partition} for x in partitions]}
        self.send(d)
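
    # For illustration, an 'assigned' event sent above might look like
    # (topic name hypothetical):
    #   {"name": "partitions_assigned",
    #    "partitions": [{"topic": "test", "partition": 0}]}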


    def on_assign (self, consumer, partitions):
        """ Rebalance on_assign callback """
        old_assignment = self.assignment
        self.assignment = [AssignedPartition(p.topic, p.partition) for p in partitions]
        # Rebuild the assignment dict first so that find_assignment() below
        # resolves against the new assignment, then move over our last seen
        # offsets so that we can report a proper minOffset even after a
        # rebalance loop.
        self.assignment_dict = {a.skey: a for a in self.assignment}
        for a in old_assignment:
            b = self.find_assignment(a.topic, a.partition)
            if b is not None:
                b.min_offset = a.min_offset

        self.send_assignment('assigned', partitions)

    def on_revoke (self, consumer, partitions):
        """ Rebalance on_revoke callback """
        # Send final consumed records prior to rebalancing to make sure
        # the latest consumed position is on par with what is about to be
        # committed.
        self.send_records_consumed(immediate=True)
        self.assignment = list()
        self.assignment_dict = dict()
        self.send_assignment('revoked', partitions)
        self.do_commit(immediate=True)


    def on_commit (self, err, partitions):
        """ Offsets Committed callback """
        if err is not None and err.code() == KafkaError._NO_OFFSET:
            self.dbg('on_commit(): no offsets to commit')
            return

        # Report consumed messages to make sure consumed position >= committed position
        self.send_records_consumed(immediate=True)

        d = {'name': 'offsets_committed',
             'offsets': []}

        if err is not None:
            d['success'] = False
            d['error'] = str(err)
        else:
            d['success'] = True
            d['error'] = ''

        for p in partitions:
            pd = {'topic': p.topic, 'partition': p.partition,
                  'offset': p.offset, 'error': str(p.error)}
            d['offsets'].append(pd)

        self.send(d)
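
    # For illustration, a successful commit reported above might serialize
    # as (topic and offset hypothetical; note that str(None) yields "None"
    # for the per-partition error field):
    #   {"name": "offsets_committed", "success": true, "error": "",
    #    "offsets": [{"topic": "test", "partition": 0,
    #                 "offset": 112, "error": "None"}]}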


    def do_commit (self, immediate=False, asynchronous=None):
        """ Commit every 1000 messages, on consume timeout,
            or when immediate is set. """
        if (self.use_auto_commit or
            self.consumed_msgs_at_last_commit + (0 if immediate else 1000) >
            self.consumed_msgs):
            return

        # Make sure we report consumption before commit,
        # otherwise tests may fail because of commit > consumed
        if self.consumed_msgs_at_last_commit < self.consumed_msgs:
            self.send_records_consumed(immediate=True)

        if asynchronous is None:
            async_mode = self.use_async_commit
        else:
            async_mode = asynchronous

        self.dbg('Committing %d messages (Async=%s)' %
                 (self.consumed_msgs - self.consumed_msgs_at_last_commit,
                  async_mode))

        try:
            self.consumer.commit(asynchronous=async_mode)
        except KafkaException as e:
            if e.args[0].code() == KafkaError._WAIT_COORD:
                self.dbg('Ignoring commit failure, still waiting for coordinator')
            elif e.args[0].code() == KafkaError._NO_OFFSET:
                self.dbg('No offsets to commit')
            else:
                raise

        self.consumed_msgs_at_last_commit = self.consumed_msgs
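
    # Note: with asynchronous=False the commit() call above blocks until
    # the broker responds and raises KafkaException on failure (hence the
    # handler); with asynchronous=True the outcome is instead reported
    # through the on_commit callback configured in __init__.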


    def msg_consume (self, msg):
        """ Handle consumed message (or error event) """
        if msg.error():
            if msg.error().code() == KafkaError._PARTITION_EOF:
                # ignore EOF
                pass
            else:
                self.err('Consume failed: %s' % msg.error(), term=True)
            return

        if self.verbose:
            # Per-message debug logging is very verbose, so it is gated
            # behind the verbose flag (off by default).
            self.dbg('Read msg from %s [%d] @ %d' %
                     (msg.topic(), msg.partition(), msg.offset()))

        if self.max_msgs >= 0 and self.consumed_msgs >= self.max_msgs:
            return # ignore extra messages

        # Find assignment.
        a = self.find_assignment(msg.topic(), msg.partition())
        if a is None:
            self.err('Received message on unassigned partition %s [%d] @ %d' %
                     (msg.topic(), msg.partition(), msg.offset()), term=True)
            return

        a.consumed_msgs += 1
        if a.min_offset == -1:
            a.min_offset = msg.offset()
        if a.max_offset < msg.offset():
            a.max_offset = msg.offset()

        self.consumed_msgs += 1

        self.send_records_consumed(immediate=False)
        self.do_commit(immediate=False)


class AssignedPartition(object):
    """ Local state container for assigned partition. """
    def __init__ (self, topic, partition):
        super(AssignedPartition, self).__init__()
        self.topic = topic
        self.partition = partition
        self.skey = '%s %d' % (self.topic, self.partition)
        self.consumed_msgs = 0
        self.min_offset = -1
        self.max_offset = 0

    def to_dict (self):
        """ Return a dict of this partition's state """
        return {'topic': self.topic, 'partition': self.partition,
                'minOffset': self.min_offset, 'maxOffset': self.max_offset}
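
    # For illustration, to_dict() for a partition that has consumed
    # offsets 5 through 42 (values hypothetical) would return:
    #   {'topic': 'test', 'partition': 0, 'minOffset': 5, 'maxOffset': 42}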


if __name__ == '__main__':

    parser = argparse.ArgumentParser(description='Verifiable Python Consumer')
    parser.add_argument('--topic', action='append', type=str, required=True)
    parser.add_argument('--group-id', dest='group.id', required=True)
    parser.add_argument('--broker-list', dest='bootstrap.servers', required=True)
    parser.add_argument('--session-timeout', type=int, dest='session.timeout.ms', default=6000)
    parser.add_argument('--enable-autocommit', action='store_true', dest='enable.auto.commit', default=False)
    parser.add_argument('--max-messages', type=int, dest='max_messages', default=-1)
    parser.add_argument('--assignment-strategy', dest='partition.assignment.strategy')
    parser.add_argument('--reset-policy', dest='topic.auto.offset.reset', default='earliest')
    parser.add_argument('--consumer.config', dest='consumer_config')
    args = vars(parser.parse_args())
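
    # An example invocation (script name, broker address, topic and group
    # id are all illustrative):
    #   ./verifiable_consumer.py --topic test --group-id grp1 \
    #       --broker-list localhost:9092 --session-timeout 6000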

    conf = {'broker.version.fallback': '0.9.0',
            'default.topic.config': dict()}

    VerifiableClient.set_config(conf, args)

    vc = VerifiableConsumer(conf)
    vc.use_auto_commit = args['enable.auto.commit']
    vc.max_msgs = args['max_messages']

    vc.dbg('Using config: %s' % conf)

    vc.dbg('Subscribing to %s' % args['topic'])
    vc.consumer.subscribe(args['topic'],
                          on_assign=vc.on_assign, on_revoke=vc.on_revoke)
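
    # Note: the on_assign/on_revoke rebalance callbacks registered above
    # are only invoked from within the consumer.poll() calls below.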


    try:
        while vc.run:
            msg = vc.consumer.poll(timeout=1.0)
            if msg is None:
                # Timeout.
                # Try reporting consumed messages
                vc.send_records_consumed(immediate=True)
                # Commit every poll() timeout instead of on every message.
                # Also commit on every 1000 messages, whichever comes first.
                vc.do_commit(immediate=True)
                continue

            # Handle message (or error event)
            vc.msg_consume(msg)

    except KeyboardInterrupt:
        pass

    vc.dbg('Closing consumer')
    vc.send_records_consumed(immediate=True)
    if not vc.use_auto_commit:
        vc.do_commit(immediate=True, asynchronous=False)

    vc.consumer.close()

    vc.send({'name': 'shutdown_complete'})

    vc.dbg('All done')
