/*
    kCA, a KDE Certification Authority management tool
    Copyright (C) 2013 Felix Tiede <info@pc-tiede.de>

    This program is free software: you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation, either version 3 of the License, or
    (at your option) any later version.

    This program is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with this program.  If not, see <http://www.gnu.org/licenses/>.

*/

#include "key.h"

#include "opensslexception.h"

#include <QtCore/QByteArray>

#include <QtNetwork/QSsl>
#include <QtNetwork/QSslKey>

#include <openssl/bio.h>
#include <openssl/bn.h>
#include <openssl/buffer.h>
#include <openssl/crypto.h>
#include <openssl/err.h>
#include <openssl/evp.h>
#include <openssl/rsa.h>
#include <openssl/x509.h>

namespace Kca
{
namespace OpenSSL
{
class Key::Private
{
public:
  /**
   * Builds the OpenSSL representation of @p key.
   *
   * Only RSA private keys are converted. For a null key, a non-RSA key or a
   * public key, pkey is left NULL.
   * @throws OpenSSLException if the key cannot be encoded as DER, if OpenSSL
   *         fails to decode it, or if wrapping it in an EVP_PKEY fails.
   */
  Private(const QSslKey& key) :
      pkey(NULL)
  {
    if (key.isNull() || key.algorithm() != QSsl::Rsa || key.type() != QSsl::PrivateKey)
      return;

    QByteArray der = key.toDer();
    if (der.isNull())
      throw OpenSSLException("Could not encode QSslKey as DER.", __PRETTY_FUNCTION__, __LINE__);

    // Let OpenSSL allocate the RSA structure itself. Reusing a pre-allocated
    // object through d2i's first argument is discouraged by the OpenSSL docs.
    const unsigned char *buffer = reinterpret_cast<const unsigned char*>(der.constData());
    RSA *rsa = d2i_RSAPrivateKey(NULL, &buffer, der.length());
    if (!rsa)
      throw OpenSSLException(ERR_error_string(ERR_get_error(), NULL), __PRETTY_FUNCTION__, __LINE__);

    EVP_PKEY *result = EVP_PKEY_new();
    if (!result) {
      RSA_free(rsa);
      throw OpenSSLException(ERR_error_string(ERR_get_error(), NULL), __PRETTY_FUNCTION__, __LINE__);
    }

    // set1 takes its own reference on success, so our reference to rsa is
    // dropped regardless of the outcome.
    const int assigned = EVP_PKEY_set1_RSA(result, rsa);
    RSA_free(rsa);
    if (!assigned) {
      EVP_PKEY_free(result);
      throw OpenSSLException(ERR_error_string(ERR_get_error(), NULL), __PRETTY_FUNCTION__, __LINE__);
    }

    pkey = result;
  }

  /**
   * Copy constructor: shares @p other's key by taking an extra reference.
   * pkey stays NULL if @p other holds no RSA key.
   */
  Private(const Private& other) :
      pkey(acquire(other.pkey))
  {
  }

  /** Destructor: drops this object's reference to the key, if any. */
  ~Private()
  {
    if (pkey)
      EVP_PKEY_free(pkey);
  }

  /**
   * Assignment operator: drops the current key reference and shares
   * @p other's key instead.
   *
   * The new reference is taken before the old one is released. Without
   * this ordering, self-assignment would release the key being kept.
   */
  Private& operator=(const Private& other)
  {
    EVP_PKEY *acquired = acquire(other.pkey);

    if (pkey)
      EVP_PKEY_free(pkey);
    pkey = acquired;

    return *this;
  }

  /** OpenSSL key referenced by this object (one reference held), or NULL. */
  EVP_PKEY *pkey;

private:
  /**
   * Returns @p key with one extra reference taken. Returns NULL, without
   * taking a reference, when @p key is NULL or not an RSA key.
   */
  static EVP_PKEY* acquire(EVP_PKEY *key)
  {
    if (!key || key->type != EVP_PKEY_RSA)
      return NULL;

    if (!CRYPTO_add(&key->references, 1, CRYPTO_LOCK_EVP_PKEY))
      return NULL;

    return key;
  }
};  // End class Key::Private

};  // End namespace OpenSSL
};  // End namespace Kca

using namespace Kca::OpenSSL;

/**
 * Wraps @p key. RSA private keys are additionally converted into an
 * OpenSSL key held by the private implementation.
 * @throws OpenSSLException if that conversion fails.
 */
Key::Key(const QSslKey& key)
  : QSslKey(key)
  , e(new Private(key))
{
}

/**
 * Copy constructor: copies the QSslKey part and shares @p other's OpenSSL
 * key through a new private implementation.
 */
Key::Key(const Key& other)
  : QSslKey(other)
  , e(new Private(*other.e))
{
}

/** Destructor: deletes the private implementation, releasing its key reference. */
Key::~Key()
{
  Private *implementation = e;
  delete implementation;
}

/**
 * Assigns @p other to this key: copies the QSslKey part and makes the
 * private implementation share @p other's OpenSSL key.
 * Self-assignment is a no-op.
 */
Key& Key::operator=(const Key& other)
{
  if (&other != this) {
    QSslKey::operator=(other);
    *e = *other.e;
  }

  return *this;
}

/**
 * Generates a new key pair of @p length bits.
 *
 * Only RSA is supported. For any other @p algorithm a default-constructed
 * Key is returned. The RSA public exponent is 65537.
 * @throws OpenSSLException if any OpenSSL step fails: exponent setup, key
 *         generation, EVP_PKEY wrapping, or DER encoding. It is also thrown
 *         if Key's constructor cannot convert the generated key.
 */
const Key Key::generateKeyPair(int length, QSsl::KeyAlgorithm algorithm)
{
  if (algorithm != QSsl::Rsa)
    return Key();

  BIGNUM *bn = BN_new();
  if (!bn)
    throw OpenSSLException(ERR_error_string(ERR_get_error(), NULL), __PRETTY_FUNCTION__, __LINE__);

  RSA *rsa = RSA_new();
  if (!rsa) {
    BN_free(bn);
    throw OpenSSLException(ERR_error_string(ERR_get_error(), NULL), __PRETTY_FUNCTION__, __LINE__);
  }

  // Public exponent 65537. A failure here is an OpenSSL error like every
  // other one in this function, so report it the same way instead of
  // silently returning a null key.
  if (!BN_set_word(bn, 0x10001)) {
    RSA_free(rsa);
    BN_free(bn);
    throw OpenSSLException(ERR_error_string(ERR_get_error(), NULL), __PRETTY_FUNCTION__, __LINE__);
  }
  if (!RSA_generate_key_ex(rsa, length, bn, NULL)) {
    RSA_free(rsa);
    BN_free(bn);
    throw OpenSSLException(ERR_error_string(ERR_get_error(), NULL), __PRETTY_FUNCTION__, __LINE__);
  }

  // Big number no longer needed
  BN_free(bn);

  EVP_PKEY *pkey = EVP_PKEY_new();
  if (!pkey) {
    RSA_free(rsa);
    throw OpenSSLException(ERR_error_string(ERR_get_error(), NULL), __PRETTY_FUNCTION__, __LINE__);
  }

  if (!EVP_PKEY_set1_RSA(pkey, rsa)) {
    EVP_PKEY_free(pkey);
    RSA_free(rsa);
    throw OpenSSLException(ERR_error_string(ERR_get_error(), NULL), __PRETTY_FUNCTION__, __LINE__);
  }

  // RSA key pair no longer needed; pkey holds its own reference (set1).
  RSA_free(rsa);

  BIO *memory = BIO_new(BIO_s_mem());
  if (!memory) {
    EVP_PKEY_free(pkey);
    throw OpenSSLException(ERR_error_string(ERR_get_error(), NULL), __PRETTY_FUNCTION__, __LINE__);
  }

  // i2d_PrivateKey_bio() returns 1 on success and 0 on failure. It never
  // returns a negative value, so a "< 0" test would miss every error.
  if (i2d_PrivateKey_bio(memory, pkey) <= 0) {
    BIO_free(memory);
    EVP_PKEY_free(pkey);
    throw OpenSSLException(ERR_error_string(ERR_get_error(), NULL), __PRETTY_FUNCTION__, __LINE__);
  }

  EVP_PKEY_free(pkey);

  BUF_MEM *buffer = NULL;
  if (!BIO_get_mem_ptr(memory, &buffer)) {
    BIO_free(memory);
    throw OpenSSLException(ERR_error_string(ERR_get_error(), NULL), __PRETTY_FUNCTION__, __LINE__);
  }

  // Copy the DER bytes out and release the BIO before building the Key.
  // Key's constructor may throw, which would otherwise leak the BIO.
  QByteArray der(buffer->data, static_cast<int>(buffer->length));
  BIO_free(memory);

  return Key(QSslKey(der, QSsl::Rsa, QSsl::Der));
}

/**
 * Returns the underlying OpenSSL key with one additional reference taken.
 * The caller therefore holds a reference that must be released with
 * EVP_PKEY_free().
 *
 * Returns NULL, without taking a reference, if this Key holds no OpenSSL
 * key. This is the case, for example, when it was built from a null,
 * non-RSA or public QSslKey.
 */
EVP_PKEY* Key::handle() const
{
  if (!e->pkey)
    return NULL;

  CRYPTO_add(&e->pkey->references, 1, CRYPTO_LOCK_EVP_PKEY);
  return e->pkey;
}
