#include "ac3filter.h"

#include <new>

#include <dsound.h>
#include <ks.h>
#include <ksmedia.h>
#include <mmreg.h>


// Fixed 8-byte header copied to the start of every S/PDIF output buffer.
// Read as little-endian 16-bit words this is presumably the IEC 61937 burst
// preamble: Pa = 0xF872 and Pb = 0x4E1F (sync words), Pc = 0x0001 (AC-3 data
// type) and Pd = 0x3800 (payload length in bits) -- TODO confirm against spec.
const uint8_t spdif_header[8] = { 0x72, 0xf8, 0x1f, 0x4e, 0x01, 0x00, 0x00, 0x38 }; 
// WAVEFORMATEXTENSIBLE speaker masks (dwChannelMask) for each output mode.
// 16 entries: entries 0-7 are eight channel layouts without LFE (judging by
// the masks: dual mono, 1/0, 2/0, 3/0, 2/1, 3/1, 2/2, 3/2), entries 8-15 are
// the same layouts with SPEAKER_LOW_FREQUENCY added.
const int ds_channels[16] = 
{
  SPEAKER_FRONT_LEFT   | SPEAKER_FRONT_RIGHT,  // double mono as stereo
  SPEAKER_FRONT_CENTER,
  SPEAKER_FRONT_LEFT   | SPEAKER_FRONT_RIGHT,
  SPEAKER_FRONT_LEFT   | SPEAKER_FRONT_CENTER | SPEAKER_FRONT_RIGHT,
  SPEAKER_FRONT_LEFT   | SPEAKER_FRONT_RIGHT  | SPEAKER_BACK_CENTER,
  SPEAKER_FRONT_LEFT   | SPEAKER_FRONT_CENTER | SPEAKER_FRONT_RIGHT   | SPEAKER_BACK_CENTER,
  SPEAKER_FRONT_LEFT   | SPEAKER_FRONT_RIGHT  | SPEAKER_BACK_LEFT     | SPEAKER_BACK_RIGHT,
  SPEAKER_FRONT_LEFT   | SPEAKER_FRONT_CENTER | SPEAKER_FRONT_RIGHT   | SPEAKER_BACK_LEFT    | SPEAKER_BACK_RIGHT,

  // Same layouts as above, each with an LFE channel.
  SPEAKER_FRONT_LEFT   | SPEAKER_FRONT_RIGHT  | SPEAKER_LOW_FREQUENCY,  // double mono as stereo
  SPEAKER_FRONT_CENTER | SPEAKER_LOW_FREQUENCY,
  SPEAKER_FRONT_LEFT   | SPEAKER_FRONT_RIGHT  | SPEAKER_LOW_FREQUENCY,
  SPEAKER_FRONT_LEFT   | SPEAKER_FRONT_CENTER | SPEAKER_FRONT_RIGHT   | SPEAKER_LOW_FREQUENCY,
  SPEAKER_FRONT_LEFT   | SPEAKER_FRONT_RIGHT  | SPEAKER_BACK_CENTER   | SPEAKER_LOW_FREQUENCY,
  SPEAKER_FRONT_LEFT   | SPEAKER_FRONT_CENTER | SPEAKER_FRONT_RIGHT   | SPEAKER_BACK_CENTER  | SPEAKER_LOW_FREQUENCY,
  SPEAKER_FRONT_LEFT   | SPEAKER_FRONT_RIGHT  | SPEAKER_BACK_LEFT     | SPEAKER_BACK_RIGHT   | SPEAKER_LOW_FREQUENCY,
  SPEAKER_FRONT_LEFT   | SPEAKER_FRONT_CENTER | SPEAKER_FRONT_RIGHT   | SPEAKER_BACK_LEFT    | SPEAKER_BACK_RIGHT     | SPEAKER_LOW_FREQUENCY
};

// Channel order of the interleaved PCM output for each of the 16 modes,
// row-for-row parallel to ds_channels. Each row lists channels in the same
// order as the speaker bits of the matching ds_channels mask (L, R, C, LFE,
// surround). Trailing 0 entries presumably mark unused slots -- confirm
// against channel_order_t / set_acmod().
const channel_order_t ds_order[16] = 
{
  { CH_L,   CH_R,   0,      0,      0,     0 },
  { CH_C,   0,      0,      0,      0,     0 },
  { CH_L,   CH_R,   0,      0,      0,     0 },
  { CH_L,   CH_R,   CH_C,   0,      0,     0 },
  { CH_L,   CH_R,   CH_S,   0,      0,     0 },
  { CH_L,   CH_R,   CH_C,   CH_S,   0,     0 },
  { CH_L,   CH_R,   CH_SL,  CH_SR,  0,     0 },
  { CH_L,   CH_R,   CH_C,   CH_SL,  CH_SR, 0 },

  // Same layouts as above, each with an LFE channel.
  { CH_L,   CH_R,   CH_LFE, 0,      0,     0 },
  { CH_C,   CH_LFE, 0,      0,      0,     0 },
  { CH_L,   CH_R,   CH_LFE, 0,      0,     0 },
  { CH_L,   CH_R,   CH_C,   CH_LFE, 0,     0 },
  { CH_L,   CH_R,   CH_LFE, CH_S,   0,     0 },
  { CH_L,   CH_R,   CH_C,   CH_LFE, CH_S,  0 },
  { CH_L,   CH_R,   CH_LFE, CH_SL,  CH_SR, 0 },
  { CH_L,   CH_R,   CH_C,   CH_LFE, CH_SL, CH_SR }
};




AC3Filter::AC3Filter(TCHAR *tszName, LPUNKNOWN punk, HRESULT *phr) :
  CTransformFilter(tszName, punk, CLSID_AC3Filter)
{
  DbgLog((LOG_TRACE, 3, "AC3Filter(%x)::AC3Filter", this));
  ASSERT(tszName);
  ASSERT(phr);

  // Streaming state: no output sample is held yet and no input timestamp
  // is pending.
  sample = 0;
  sample_buffer = 0;
  current_block = 0;
  sample_time = 0;
  need_to_timestamp = false;
  input_sample_rate = 48000;

  // set_speakers() snapshots speakers / buffer_size / format for rollback
  // before overwriting them, so they must hold defined values first.
  speakers = MODE_STEREO;
  buffer_size = 0;
  ZeroMemory(&format, sizeof(WAVEFORMATEXTENSIBLE));

  set_speakers(default_speakers());
  load_matrix();
  load_params();  // default matrix can be overwritten
}


AC3Filter::~AC3Filter()
{
  DbgLog((LOG_TRACE, 3, "AC3Filter(%x)::~AC3Filter", this));

  // Give back the output sample still held between deliveries, if any.
  if (sample != 0)
  {
    sample->Release();
  }

  // Persist the current settings and mixing matrix.
  save_params();
  save_matrix();
}

// Class factory entry point. Returns the new filter, or 0 with *phr set to
// E_OUTOFMEMORY (allocation failed, whether new throws or returns null) or
// E_UNEXPECTED (any other exception during construction).
CUnknown * WINAPI 
AC3Filter::CreateInstance(LPUNKNOWN punk, HRESULT *phr) 
{
  DbgLog((LOG_TRACE, 3, "AC3Filter::CreateInstance"));
  AC3Filter *pobj = 0;
  try 
  {
    pobj = new AC3Filter("AC3Filter", punk, phr);
  }
  catch (const std::bad_alloc &)
  {
    // A throwing operator new is an out-of-memory condition, not an
    // unexpected failure; it is reported as E_OUTOFMEMORY below.
    pobj = 0;
  }
  catch (...)
  {
    *phr = E_UNEXPECTED;
    return 0;
  }

  // Covers both a non-throwing (null-returning) new and std::bad_alloc.
  if (!pobj) 
    *phr = E_OUTOFMEMORY;
  return pobj;
}

STDMETHODIMP 
AC3Filter::NonDelegatingQueryInterface(REFIID riid, void **ppv)
{
  CheckPointer(ppv, E_POINTER);

  // Interfaces exposed directly by this filter.
  if (riid == IID_IAC3Filter)
    return GetInterface(static_cast<IAC3Filter *>(this), ppv);

  if (riid == IID_ISpecifyPropertyPages)
    return GetInterface(static_cast<ISpecifyPropertyPages *>(this), ppv);

  // Anything else is resolved by the transform-filter base class.
  return CTransformFilter::NonDelegatingQueryInterface(riid, ppv);
}

// Releases the currently held output sample (if any) and acquires a fresh
// delivery buffer of buffer_size bytes, stamping it with the pending input
// timestamp when there is one. Throws a string on allocation failure.
void
AC3Filter::new_sample()
{
  DbgLog((LOG_TRACE, 3, "AC3Filter::new_sample"));

  // Clear the pointer right after Release(): if GetDeliveryBuffer() below
  // fails without writing its out-parameter, `sample` must not keep pointing
  // at the sample we just gave back.
  if (sample)
  {
    sample->Release();
    sample = 0;
  }

  if FAILED(m_pOutput->GetDeliveryBuffer(&sample, 0, 0, 0))
  {
    sample = 0;
    sample_buffer = 0;
    throw "Get delivery buffer failed";
  }
  ASSERT(sample->GetSize() >= buffer_size);

  sample->GetPointer((BYTE**)&sample_buffer);

  // Only the first output sample after a timestamped input sample gets a time.
  if (need_to_timestamp)
  {
    need_to_timestamp = false;
    sample->SetTime(&sample_time, 0);
  }
  else
    sample->SetTime(0, 0);

  sample->SetMediaTime(0, 0);
  sample->SetSyncPoint(true);
  sample->SetActualDataLength(buffer_size);
}

// Called once per decoded AC-3 frame. In PCM modes the base decoder handles
// the frame; in S/PDIF mode the raw frame is wrapped in a burst (header +
// byte-swapped payload) and delivered as one output sample.
void 
AC3Filter::frame()
{
  if (!IS_MODE_SPDIF(speakers))
  {
    Decoder::frame();
    return;
  }

  // No output sample is held before the first discontinuity, so make sure
  // there is a buffer to write into.
  if (!sample)
    new_sample();

  memset(sample_buffer, 0, buffer_size);
  memcpy(sample_buffer, spdif_header, 8);
  // The 8-byte header fills 16-bit words 0..3 (assumes sample_buffer points
  // to 16-bit words). Word 3 is the length field (Pd, payload length in bits),
  // so patch it with the real frame size. Word 4 is the first payload word
  // and is written by _swab below.
  sample_buffer[3] = bsi.frame_size * 8;
  // Payload: the AC-3 frame, byte-swapped into 16-bit words, padded to an even length.
  _swab((char *)buf, (char *)(sample_buffer+4), (bsi.frame_size + 1) & ~1);
  m_pOutput->Deliver(sample);

#ifdef DEBUG
    REFERENCE_TIME begin_time;
    REFERENCE_TIME media_time;
    REFERENCE_TIME end_time;

    if FAILED(sample->GetTime(&begin_time, &end_time))
      begin_time = -10000;

    if FAILED(sample->GetMediaTime(&media_time, &end_time))
      media_time = -10000;

    DbgLog((LOG_TRACE, 3, "Sample sent (spdif):\tsize: %i\tdiscont: %i\t-time: %i\tmedia time: %i\terr: %i",
      sample->GetActualDataLength(), 
      sample->IsDiscontinuity() == S_OK, 
      int(__int64(begin_time)/10000),
      int(__int64(media_time)/10000),
      errors));
#endif

  new_sample();
}

void 
AC3Filter::block()
{
  int nchannels = MODE_NCHANS(speakers);
  for (int i = 0; i < 256; i++)        
    for (int j = 0; j < nchannels; j++)
      sample_buffer[i*nchannels+j + current_block * 256 * nchannels] = int(min(32000.0, samples[j][i] * 32000));
  next_block();
}

// Advances the block slot within the current output sample. Once
// blocks_per_sample blocks have been written the sample is delivered
// downstream and filling restarts at slot 0 of a fresh sample.
void 
AC3Filter::next_block()
{
  if (++current_block >= blocks_per_sample)
  {
    m_pOutput->Deliver(sample);

#ifdef DEBUG
    REFERENCE_TIME t_start;
    REFERENCE_TIME t_media;
    REFERENCE_TIME t_stop;

    if (FAILED(sample->GetTime(&t_start, &t_stop)))
      t_start = -10000;

    if (FAILED(sample->GetMediaTime(&t_media, &t_stop)))
      t_media = -10000;

    DbgLog((LOG_TRACE, 3, "Sample sent:\tsize: %i\tdiscont: %i\t-time: %i\tmedia time: %i\terr: %i",
      sample->GetActualDataLength(), 
      sample->IsDiscontinuity() == S_OK, 
      int(__int64(t_start)/10000),
      int(__int64(t_media)/10000),
      errors));
#endif

    new_sample();
    current_block = 0;
  }

  // Always leave a writable output sample for the next block.
  if (sample == 0)
    new_sample();
}



// Feeds one input sample to the decoder. Output samples are delivered by
// frame()/next_block() themselves, so S_FALSE is returned to tell
// CTransformFilter not to deliver pOut. Failures are reported as E_UNEXPECTED.
HRESULT 
AC3Filter::Transform(IMediaSample *pIn, IMediaSample *pOut)
{
try
{
  DbgLog((LOG_TRACE, 3, "AC3Filter::Transform"));
  CAutoLock data_flow_lock(&data_flow);

  BYTE *pSourceBuffer;
  REFERENCE_TIME end_time;

  pIn->GetPointer(&pSourceBuffer);

  // Remember the input start time (if any); new_sample() stamps the next
  // output sample it creates with it.
  if FAILED(pIn->GetTime(&sample_time, &end_time))
    need_to_timestamp = false;
  else
    need_to_timestamp = true;

  if (pIn->IsDiscontinuity() == S_OK)
  {
    // Drop decoder state and the partially filled output sample. Filling of
    // the new sample must restart at block 0; otherwise the blocks before the
    // old current_block would keep stale buffer contents.
    Decoder::reset();
    current_block = 0;
    new_sample();
    sample->SetDiscontinuity(true);
  }

#ifdef DEBUG
  REFERENCE_TIME begin_time;
  REFERENCE_TIME media_time;
  REFERENCE_TIME end_time2;

  if FAILED(pIn->GetTime(&begin_time, &end_time2))
    begin_time = -10000;

  if FAILED(pIn->GetMediaTime(&media_time, &end_time2))
    media_time = -10000;

  DbgLog((LOG_TRACE, 3, "Input sample:\tsize: %i\tdiscont: %i\t+time: %i\tmedia time: %i\terr: %i",
    pIn->GetActualDataLength(), 
    pIn->IsDiscontinuity() == S_OK, 
    int(__int64(begin_time)/10000),
    int(__int64(media_time)/10000),
      errors));
#endif

  decode(pSourceBuffer, pIn->GetActualDataLength());
  return S_FALSE;
}
// String literals are thrown as `const char *`; a `char *` handler does not
// match them under conforming compilers.
catch (const char *s)
{
  DbgLog((LOG_TRACE, 3, "Transform: %s", s));
  return E_UNEXPECTED;
}
catch (...)
{
  DbgLog((LOG_TRACE, 3, "Transform: some strange happen"));
  return E_UNEXPECTED;
}
}


// Accepts AC-3 input: AC-3 in MPEG-2 PES, audio with the DOLBY_AC3 subtype,
// or audio whose WAVEFORMATEX tag is 0x2000. Returns NOERROR or E_INVALIDARG.
HRESULT 
AC3Filter::CheckInputType(const CMediaType *mtIn)
{
  DbgLog((LOG_TRACE, 3, "AC3Filter(%x)::CheckInputType", this));

  ASSERT(mtIn);
  if (*mtIn->Type() == MEDIATYPE_MPEG2_PES && *mtIn->Subtype() == MEDIASUBTYPE_DOLBY_AC3)
    return NOERROR;

  if (*mtIn->Type() != MEDIATYPE_Audio)
    return E_INVALIDARG;

  if (*mtIn->Subtype() == MEDIASUBTYPE_DOLBY_AC3)
    return NOERROR;

  // Tag 0x2000 presumably identifies AC-3 in a WAVE-style format block
  // (TODO confirm). Only read wFormatTag when the format block is present
  // and large enough to hold it; an upstream pin may propose an empty block.
  if (*mtIn->FormatType() == FORMAT_WaveFormatEx &&
      mtIn->Format() != NULL &&
      mtIn->FormatLength() >= sizeof(WORD) &&
      ((WAVEFORMATEX *)mtIn->Format())->wFormatTag == 0x2000) 
    return NOERROR;

  return E_INVALIDARG;
} 



// Accepts an input/output pair when the input is decodable and the output
// carries a WAVEFORMATEX block whose base structure equals our current
// output format.
HRESULT 
AC3Filter::CheckTransform(const CMediaType *mtIn, const CMediaType *mtOut)
{
  DbgLog((LOG_TRACE, 3, "AC3Filter(%x)::CheckTransform", this));
  ASSERT(mtIn);
  ASSERT(mtOut);

  HRESULT hr = CheckInputType(mtIn);
  if (FAILED(hr))
    return hr;

  const bool is_wave_format = (*mtOut->FormatType() == FORMAT_WaveFormatEx);
  if (!is_wave_format || mtOut->FormatLength() < sizeof(WAVEFORMATEX))
    return E_INVALIDARG;

  // Only the WAVEFORMATEX part is compared; extensible fields are not checked.
  const bool same_format = memcmp(mtOut->Format(), &format, sizeof(WAVEFORMATEX)) == 0;
  return same_format ? NOERROR : E_INVALIDARG;
}


// Asks the allocator for 10 buffers of max_buffer_size bytes each. Fails with
// E_FAIL if fewer or smaller buffers are granted.
HRESULT 
AC3Filter::DecideBufferSize(IMemAllocator *pAlloc, ALLOCATOR_PROPERTIES *pProperties)
{
  DbgLog((LOG_TRACE, 3, "AC3Filter(%x)::DecideBufferSize", this));

  // Is the input pin connected
  if (m_pInput->IsConnected() == FALSE)
    return E_UNEXPECTED;

  ASSERT(pAlloc);
  ASSERT(pProperties);

  // max_buffer_size is presumably the largest buffer_size of any output
  // mode, so one allocation serves every speaker configuration -- confirm.
  pProperties->cBuffers = 10;
  pProperties->cbBuffer = max_buffer_size;

  ASSERT(pProperties->cbBuffer);

  ALLOCATOR_PROPERTIES Actual;
  HRESULT hr = pAlloc->SetProperties(pProperties, &Actual);
  if FAILED(hr)
    return hr;

  // An allocator may legitimately grant more or larger buffers than
  // requested; only fewer or smaller ones are rejected.
  if (pProperties->cBuffers > Actual.cBuffers ||
      pProperties->cbBuffer > Actual.cbBuffer)
    return E_FAIL;

  return NOERROR;
}


// Offers exactly one output type (position 0): audio, PCM subtype, with a
// WAVEFORMATEXTENSIBLE-sized copy of the current output format.
HRESULT 
AC3Filter::GetMediaType(int iPosition, CMediaType *pMediaType)
{
  DbgLog((LOG_TRACE, 3, "AC3Filter(%x)::GetMediaType #%i", this, iPosition));
  ASSERT(pMediaType);

  // The output format depends on the input, so nothing is offered before
  // the input pin is connected.
  if (m_pInput->IsConnected() == FALSE)
    return E_UNEXPECTED;

  if (iPosition < 0)
    return E_INVALIDARG;
  if (iPosition != 0)
    return VFW_S_NO_MORE_ITEMS;

  BYTE *format_block = pMediaType->AllocFormatBuffer(sizeof(WAVEFORMATEXTENSIBLE));
  if (format_block == NULL)
    return E_OUTOFMEMORY;
  memcpy(format_block, &format, sizeof(WAVEFORMATEXTENSIBLE));

  pMediaType->SetType(&MEDIATYPE_Audio);
  pMediaType->SetSubtype(&MEDIASUBTYPE_PCM);
  pMediaType->SetFormatType(&FORMAT_WaveFormatEx);

  return NOERROR;
}

// Records properties of the negotiated input type: whether it arrives as
// MPEG-2 PES, and (from a WAVEFORMATEX block, when present) the sample rate,
// which is copied into the output format. Output-pin types are ignored.
HRESULT 
AC3Filter::SetMediaType(PIN_DIRECTION direction, const CMediaType *pmt)
{
  DbgLog((LOG_TRACE, 3, "AC3Filter(%x)::SetMediaType", this));
  ASSERT(pmt);

  if (direction != PINDIR_INPUT)
    return NOERROR;

  // PES-ness depends on the major type only. CheckInputType() accepts PES
  // input with any format type, so this must not depend on the format block.
  is_pes = ((*pmt->Type()) == MEDIATYPE_MPEG2_PES);

  // Take the sample rate only from a complete wave format block, and ignore
  // a zero rate so the output format never advertises 0 Hz.
  if (*pmt->FormatType() == FORMAT_WaveFormatEx &&
      pmt->Format() != NULL &&
      pmt->FormatLength() >= sizeof(WAVEFORMATEX))
  {
    DWORD rate = ((WAVEFORMATEX*)(pmt->Format()))->nSamplesPerSec;
    if (rate != 0)
    {
      input_sample_rate = rate;
      format.Format.nSamplesPerSec = input_sample_rate;
      format.Format.nAvgBytesPerSec = input_sample_rate * format.Format.nBlockAlign;
    }
  }

  return NOERROR;
}








// Maps the DirectSound speaker configuration of the default device to an
// output mode. Falls back to MODE_STEREO if DirectSound cannot be created or
// queried, or the configuration is not one of the handled values.
acmod_t      
AC3Filter::default_speakers()
{
  const acmod_t fallback = MODE_STEREO;

  IDirectSound *dsound;
  if (FAILED(DirectSoundCreate(0, &dsound, 0)))
  {
    DbgLog((LOG_TRACE, 3, "Cannot create DirectSound object"));
    return fallback;
  }

  HWND wnd = GetForegroundWindow();
  if (wnd == NULL)
    wnd = GetDesktopWindow();

  DWORD config = 0;
  bool have_config = false;
  if (FAILED(dsound->SetCooperativeLevel(wnd, DSSCL_PRIORITY)))
  {
    DbgLog((LOG_TRACE, 3, "Cannot set cooperation level")); 
  }
  else if (FAILED(dsound->GetSpeakerConfig(&config)))
  {
    DbgLog((LOG_TRACE, 3, "Cannot get speaker config"));
  }
  else
  {
    have_config = true;
  }
  dsound->Release();

  if (!have_config)
    return fallback;

  acmod_t mode = fallback;
  switch (DSSPEAKER_CONFIG(config))
  {
    case DSSPEAKER_5POINT1:    mode = MODE_5_1;    break;
    case DSSPEAKER_HEADPHONE:  mode = MODE_STEREO; break;
    case DSSPEAKER_MONO:       mode = MODE_MONO;   break;
    case DSSPEAKER_QUAD:       mode = MODE_QUADRO; break;
    case DSSPEAKER_STEREO:     mode = MODE_STEREO; break;
    default:                                       break;
  }
  return mode;
}

// Switches the output mode (PCM channel layout or S/PDIF passthrough).
// Only allowed while stopped. If the output pin is connected it is
// reconnected with the new format; on failure the previous configuration is
// restored and the failing HRESULT is returned.
// Returns S_OK, E_UNEXPECTED (not stopped / no graph) or the reconnect error.
STDMETHODIMP
AC3Filter::set_speakers(acmod_t _speakers)
{
  if (!IsStopped())
    return E_UNEXPECTED;

  CAutoLock data_flow_lock(&data_flow);

  // for rollback
  acmod_t old_speakers    = speakers;
  DWORD   old_buffer_size = buffer_size;
  WAVEFORMATEXTENSIBLE old_format;
  memcpy(&old_format, &format, sizeof(WAVEFORMATEXTENSIBLE));


  /////////////////////////////////////////////////////////
  // Determine number of speakers in system if not specified
  //

  DbgLog((LOG_TRACE, 3, "get acmod: %i", _speakers));
  speakers = _speakers;
  if (IS_MODE_ERROR(speakers))
    speakers = default_speakers();

  // ds_channels / ds_order have 16 entries (8 layouts, with and without LFE),
  // so only the low 4 bits may index them; any higher mode bits would read
  // past the end of the tables.
  const int layout = speakers & 15;

  if (!IS_MODE_SPDIF(speakers))
  {
    // non-SPDIF format
    int nchannels = MODE_NCHANS(speakers);
    ZeroMemory(&format, sizeof(WAVEFORMATEXTENSIBLE));
    format.Format.wFormatTag = nchannels <= 2? WAVE_FORMAT_PCM: WAVE_FORMAT_EXTENSIBLE;
    format.Format.nChannels = nchannels;
    format.Format.nSamplesPerSec = input_sample_rate;
    format.Format.wBitsPerSample = 16;
    format.Format.nBlockAlign = 2*nchannels;
    format.Format.nAvgBytesPerSec = format.Format.nSamplesPerSec * format.Format.nBlockAlign;
    format.Format.cbSize = 22;
    format.Samples.wValidBitsPerSample = 16;
    format.SubFormat = KSDATAFORMAT_SUBTYPE_PCM;
    format.dwChannelMask = ds_channels[layout];
    buffer_size = 256*2*nchannels*blocks_per_sample;
    DbgLog((LOG_TRACE, 3, "trying to set output mode: acmod: %i, sample_rate: %i, n_channels: %i, blocks_per_sample: %i, buffer_size: %i", speakers, input_sample_rate, nchannels, blocks_per_sample, buffer_size));
  }
  else
  {
    // SPDIF format: always 2 x 16-bit "PCM" carrying the AC-3 bursts
    ZeroMemory(&format, sizeof(WAVEFORMATEXTENSIBLE));
    format.Format.wFormatTag = WAVE_FORMAT_DOLBY_AC3_SPDIF;
    format.Format.nChannels = 2;
    format.Format.nSamplesPerSec = input_sample_rate;
    format.Format.wBitsPerSample = 16;
    format.Format.nBlockAlign = 4;
    format.Format.nAvgBytesPerSec = format.Format.nSamplesPerSec * format.Format.nBlockAlign;
    format.Format.cbSize = 0;
    buffer_size = 0x1800;
    DbgLog((LOG_TRACE, 3, "trying to set output mode: spdif"));
  }


  if (m_pOutput && m_pOutput->IsConnected())
  {
    DbgLog((LOG_TRACE, 3, "Reconnecting..."));

    CMediaType media_type;
    HRESULT hr = GetMediaType(0, &media_type);

    // Every step is checked: a failed QueryInterface must not lead to a call
    // through an uninitialized IGraphConfig pointer.
    IGraphConfig *pConfig = 0;
    if (SUCCEEDED(hr))
      hr = m_pGraph ? m_pGraph->QueryInterface(IID_IGraphConfig, (void **)&pConfig) : E_UNEXPECTED;

    if (SUCCEEDED(hr))
    {
      hr = pConfig->Reconnect(
              m_pOutput,  // Start from output pin.
              NULL,       // Search downstream for a suitable input pin.
              &media_type,// First connection type
              NULL,       // Use this filter.
              NULL, 
              0); 
      pConfig->Release();
    }

#ifdef DEBUG
    switch (hr)
    {
      case E_INVALIDARG:
        DbgLog((LOG_TRACE, 3, "Reconnect: E_INVALIDARG")); break;
      case E_NOINTERFACE:
        DbgLog((LOG_TRACE, 3, "Reconnect: E_NOINTERFACE")); break;
      case VFW_E_CANNOT_CONNECT:
        DbgLog((LOG_TRACE, 3, "Reconnect: VFW_E_CANNOT_CONNECT")); break;
      case VFW_E_STATE_CHANGED:
        DbgLog((LOG_TRACE, 3, "Reconnect: VFW_E_STATE_CHANGED")); break;
      case S_OK:
        DbgLog((LOG_TRACE, 3, "Reconnecting ok.")); break;
      default:
        DbgLog((LOG_TRACE, 3, "return code: %x", hr)); break;
    }
#endif
    // rollback: keep the previous configuration and report the failure
    // instead of pretending the new mode was applied
    if FAILED(hr)
    {
      speakers = old_speakers;
      buffer_size = old_buffer_size;
      memcpy(&format, &old_format, sizeof(WAVEFORMATEXTENSIBLE));
      return hr;
    }
  }

  current_block = 0;
  set_acmod(speakers, ds_order[layout]);
  Decoder::reset();
  return S_OK;
}

