/* Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. */
#ifdef __cplusplus
extern "C" {
#include "libavutil/channel_layout.h"
#include "libavutil/common.h"
#include "libavutil/opt.h"
#include "avcodec.h"
#include "internal.h"
#include "libavutil/log.h"
#include "AACDecoder.h"
}
#endif

/* Number of channels the PCM output buffer is sized for. */
#define DECODER_MAX_CHANNELS 10
/* Per-channel size, in bytes, of the PCM output buffer: 2048 INT_PCM samples.
 * The original note calls this "2 frames of data" (presumably two 1024-sample
 * frames; confirm against the decoder library). Fully parenthesized so the
 * macro expands safely inside any expression, not only a multiplication. */
#define DECODER_BUFFSIZE (2048 * sizeof(INT_PCM))

/* Transport type passed to initializeDecoder(). The name suggests raw MPEG-H
 * audio; NOTE(review): confirm the value against the decoder library headers. */
#define TT_MHA_RAW 60

/* Return-code type used for the decoder library calls in this file. */
typedef int32_t status_t;

using namespace piano;

/* Private decoder state, stored in AVCodecContext.priv_data. */
typedef struct XHEAACDecContext
{
    /* First member, as FFmpeg expects when the codec sets .priv_class. */
    const AVClass *classptr;
    /* Opaque decoder instance returned by createDecoder(); checked for NULL before use. */
    void *handle;
    /* PCM output buffer filled by decodeFrame() and copied into each output frame. */
    uint8_t *decoder_buffer;
    /* Size of decoder_buffer in bytes (set to 0 if its allocation fails). */
    int decoder_buffer_size;
    /* Channel count for which cache_channel_layout was last computed. */
    int cache_channel_cnt;
    /* Channel layout chosen for cache_channel_cnt and published on avctx. */
    uint64_t cache_channel_layout;
} XHEAACDecContext;

static const char *get_item_name(void *ptr)
{
    return (*(AVClass **) ptr)->class_name;
}

/* Flags for audio decoding options; parenthesized so it combines safely. */
#define AD (AV_OPT_FLAG_AUDIO_PARAM | AV_OPT_FLAG_DECODING_PARAM)

/* Private options of this decoder. None are defined yet, so the table holds
 * only the { NULL } terminator the AVOption API requires. It replaces the
 * non-standard zero-length array `{}`. */
static const AVOption libxheaac_decode_options[] = {
    { NULL },
};

static const AVClass libxheaac_decode_class = {
    .class_name = "xheaac_decode",
    .item_name = av_default_item_name,
    .option = libxheaac_decode_options,
    .version = LIBAVUTIL_VERSION_INT,
};

/*
 * Query the decoder for the current stream parameters and publish them on
 * avctx: sample_rate, frame_size, channels and channel_layout.
 *
 * The channel layout is recomputed only when the channel count changes, and
 * it is cached in s->cache_channel_layout. The requested layout is used when
 * its channel count matches the decoded stream exactly. The requested layout
 * is avctx->request_channel_layout, or AV_CH_LAYOUT_5POINT1_POINT4 if unset.
 * Otherwise FFmpeg's default layout for that channel count is used (0 if none
 * is known). A layout whose channel count differs from avctx->channels would
 * make ff_get_buffer() fail with "Inconsistent channel configuration".
 *
 * The caller's avctx->request_channel_layout is never modified.
 *
 * Returns 0 on success or a negative AVERROR code.
 */
static int get_stream_info(AVCodecContext *avctx)
{
    XHEAACDecContext *s = (XHEAACDecContext *)avctx->priv_data;
    void *info = getDecoderStreamInfo(s->handle);

    if (!info)
    {
        av_log(avctx, AV_LOG_ERROR, "Unable to get stream info\n");
        return AVERROR_UNKNOWN;
    }
    int sampleRate = findSampleRate(info);
    if (sampleRate <= 0)
    {
        av_log(avctx, AV_LOG_ERROR, "Stream info not initialized\n");
        return AVERROR_UNKNOWN;
    }

    int frameSize = findFrameSize(info);
    int numChannels = findNumChannels(info);
    // These values later size the output frame and the PCM copy; reject nonsense.
    if (frameSize <= 0 || numChannels <= 0)
    {
        av_log(avctx, AV_LOG_ERROR, "Invalid stream info: %d channels, frame size %d\n",
               numChannels, frameSize);
        return AVERROR_INVALIDDATA;
    }

    avctx->sample_rate = sampleRate;
    avctx->frame_size = frameSize;
    avctx->channels = numChannels;

    if (avctx->channels != s->cache_channel_cnt) {
        // Use a local default instead of overwriting the caller's request.
        uint64_t requested = avctx->request_channel_layout
                                 ? avctx->request_channel_layout
                                 : (uint64_t)AV_CH_LAYOUT_5POINT1_POINT4;

        s->cache_channel_cnt = avctx->channels;
        if (av_get_channel_layout_nb_channels(requested) == avctx->channels) {
            s->cache_channel_layout = requested;
        }
        else {
            // Layout must agree with the channel count; 0 means "unknown layout".
            s->cache_channel_layout = (uint64_t)av_get_default_channel_layout(avctx->channels);
        }
    }

    avctx->channel_layout = s->cache_channel_layout;

    return 0;
}

/*
 * Release the decoder instance and the PCM output buffer.
 *
 * s->handle is reset to NULL after deinitializeDecoder(), and av_freep() NULLs
 * the buffer pointer. A repeated call therefore releases nothing twice.
 *
 * Always returns 0.
 */
static av_cold int libxheaac_decode_close(AVCodecContext *avctx)
{
    XHEAACDecContext *s = (XHEAACDecContext *)avctx->priv_data;
    // Lifecycle trace: debug-level only, newline-terminated.
    av_log(avctx, AV_LOG_DEBUG, "close call\n");

    if (s->handle)
    {
        deinitializeDecoder(s->handle);
        s->handle = NULL;
    }
    av_freep(&s->decoder_buffer);
    s->decoder_buffer_size = 0;

    return 0;
}
/*
 * Create and configure the decoder instance, select packed signed 16-bit
 * output (AV_SAMPLE_FMT_S16) and allocate the PCM output buffer.
 *
 * On every failure after createDecoder() succeeds, the instance is released
 * and s->handle is reset to NULL. The codec sets FF_CODEC_CAP_INIT_CLEANUP,
 * so the close callback runs after a failed init, and it only deinitializes a
 * non-NULL handle. The instance is therefore never released twice.
 *
 * Returns 0 on success or a negative AVERROR code.
 */
static av_cold int libxheaac_decode_init(AVCodecContext *avctx)
{
    XHEAACDecContext *s = (XHEAACDecContext *)avctx->priv_data;
    // Lifecycle trace: debug-level only, newline-terminated.
    av_log(avctx, AV_LOG_DEBUG, "init call\n");

    status_t err;

    s->handle = createDecoder();

    if (!s->handle)
    {
        av_log(avctx, AV_LOG_ERROR, "Error creating decoder instance\n");
        return AVERROR_UNKNOWN;
    }

    int32_t nrOfLayers = 1;
    // TODO: the transport type should be determined by the demuxer.
    err = initializeDecoder(s->handle, TT_MHA_RAW, nrOfLayers);

    if (err != AAC_DEC_OK)
    {
        av_log(avctx, AV_LOG_ERROR, "Error initializing decoder\n");
        deinitializeDecoder(s->handle);
        s->handle = NULL;
        return AVERROR_UNKNOWN;
    }

    // Hand the codec configuration (extradata) to the decoder.
    err = configureRaw(s->handle, &avctx->extradata,
                            (uint32_t *)&avctx->extradata_size);

    if (err != AAC_DEC_OK)
    {
        av_log(avctx, AV_LOG_ERROR, "Unable to set extradata\n");
        deinitializeDecoder(s->handle);
        s->handle = NULL;
        return AVERROR_INVALIDDATA;
    }

    avctx->sample_fmt = AV_SAMPLE_FMT_S16;

    s->decoder_buffer_size = DECODER_BUFFSIZE * DECODER_MAX_CHANNELS;
    s->decoder_buffer = (uint8_t *)av_malloc(s->decoder_buffer_size);
    // Force the first get_stream_info() call to compute a channel layout.
    s->cache_channel_cnt = 0;
    s->cache_channel_layout = 0;

    if (!s->decoder_buffer)
    {
        s->decoder_buffer_size = 0;
        deinitializeDecoder(s->handle);
        s->handle = NULL;
        return AVERROR(ENOMEM);
    }

    return 0;
}

/*
 * Decode one packet.
 *
 * Steps:
 * - Feed the packet to the decoder with fillDecoder().
 * - Try to decode one frame into s->decoder_buffer.
 * - If the decoder reports AAC_DEC_NOT_ENOUGH_BITS, produce no frame and
 *   return early.
 * - Otherwise refresh the stream parameters, then copy the PCM samples into
 *   a newly allocated packed S16 frame.
 *
 * The copy is bounded by decoder_buffer_size. A stream that reports more
 * channels or samples than the buffer holds is rejected instead of
 * over-reading the buffer.
 *
 * Returns avpkt->size minus the byte count fillDecoder() leaves in avail_bytes
 * (presumably the bytes consumed), or a negative AVERROR code.
 */
static int libxheaac_decode_frame(AVCodecContext *avctx, void *data,
                                      int *got_frame_ptr, AVPacket *avpkt)
{
    XHEAACDecContext *s = (XHEAACDecContext *)avctx->priv_data;
    AVFrame *frame = (AVFrame *)data;

    int ret;
    status_t err;
    unsigned int avail_bytes = avpkt->size;

    err = fillDecoder(s->handle, &avpkt->data, (const unsigned int*)&avpkt->size, &avail_bytes);

    if (err != AAC_DEC_OK)
    {
        av_log(avctx, AV_LOG_ERROR, "Error filling decoder [%x]\n", err);
        return AVERROR_INVALIDDATA;
    }

    err = decodeFrame(s->handle, (INT_PCM *)s->decoder_buffer, s->decoder_buffer_size / sizeof(INT_PCM), 0);

    if (err == AAC_DEC_NOT_ENOUGH_BITS)
    {
        // Not an error: the decoder needs more input before it can emit a frame.
        av_log(avctx, AV_LOG_WARNING,"not enough bits\n");
        return avpkt->size - avail_bytes;
    }

    if (err != AAC_DEC_OK)
    {
        av_log(avctx, AV_LOG_ERROR,"error decoding frame %x\n", err);
        return AVERROR_UNKNOWN;
    }

    if ((ret = get_stream_info(avctx)) < 0)
    {
        return ret;
    }

    // Bytes of packed S16 PCM for one frame, computed in 64 bits to avoid overflow.
    const int bytes_per_sample = av_get_bytes_per_sample(AV_SAMPLE_FMT_S16);
    const int64_t pcm_bytes = (int64_t)avctx->channels * avctx->frame_size * bytes_per_sample;

    // Stream info is not bounded by DECODER_MAX_CHANNELS; never copy past the buffer.
    if (avctx->channels <= 0 || avctx->frame_size <= 0 ||
        pcm_bytes > s->decoder_buffer_size)
    {
        av_log(avctx, AV_LOG_ERROR,
               "Decoded frame (%d channels x %d samples) does not fit the output buffer\n",
               avctx->channels, avctx->frame_size);
        return AVERROR_INVALIDDATA;
    }

    frame->nb_samples = avctx->frame_size;
    frame->format = AV_SAMPLE_FMT_S16;

    if ((ret = ff_get_buffer(avctx, frame, 0)) < 0)
    {
        return ret;
    }

    // Packed (non-planar) S16: all channels are interleaved in plane 0.
    memcpy(frame->extended_data[0], s->decoder_buffer, (size_t)pcm_bytes);

    *got_frame_ptr = 1;

    return avpkt->size - avail_bytes;
}

/*
 * Flush callback. Decoder-side flushing is not implemented yet: the function
 * only checks whether a decoder instance exists and otherwise does nothing.
 */
static av_cold void libxheaac_decode_flush(AVCodecContext *avctx)
{
    const XHEAACDecContext *ctx = (const XHEAACDecContext *)avctx->priv_data;

    if (ctx->handle == NULL) {
        return;
    }
    /* TODO: flush the decoder's internal state, e.g. flushDecoder(ctx->handle). */
}

AVCodec ff_libxheaac_decoder = {
    .name                   = "libxheaac_decoder",
    .long_name              = "libxheaac_decoder",
    .type = AVMEDIA_TYPE_AUDIO,
    .id = AV_CODEC_ID_XHEAAC,
    .capabilities = AV_CODEC_CAP_DR1 | AV_CODEC_CAP_CHANNEL_CONF,
    .supported_framerates   = NULL,
    .pix_fmts               = NULL,
    .supported_samplerates  = NULL,
    .sample_fmts            = (const enum AVSampleFormat[]) {
        AV_SAMPLE_FMT_FLTP, AV_SAMPLE_FMT_NONE
    },
    .channel_layouts        = 0,
    .max_lowres             = 0,
    .priv_class             = &libxheaac_decode_class,
    .profiles               = NULL,
    .wrapper_name           = "libamazon_codec_decoder",
    .priv_data_size         = sizeof(XHEAACDecContext),
    .next                   = NULL,
    .init_thread_copy       = NULL,
    .update_thread_context  = NULL,
    .defaults               = NULL,
    .init_static_data       = NULL,
    .init                   = libxheaac_decode_init,
    .encode_sub             = NULL,
    .encode2                = NULL,
    .decode                 = libxheaac_decode_frame,
    .close                  = libxheaac_decode_close,
    .send_frame             = NULL,
    .receive_packet         = NULL,
    .receive_frame          = NULL,
    .flush                  = libxheaac_decode_flush,
    .caps_internal          = FF_CODEC_CAP_INIT_THREADSAFE |
                                FF_CODEC_CAP_INIT_CLEANUP,
    .bsfs                   = NULL,
    .hw_configs             = NULL,
};
