/** @file mariadb_ps.cpp
 *
 * Test program for the MariaDB prepared-statement (PS) protocol and bulk
 * operations. Originally written to reproduce
 * https://jira.mariadb.org/browse/MDEV-23872.
 *
 * Requires a C++11-capable compiler and libmariadb.
 * Compile with:
 *
 *   g++ -o mariadb_ps mariadb_ps.cpp -lmariadb -lpthread
 *
 * Usage help:
 *
 *   ./mariadb_ps help
 */

#include <mariadb/mysql.h>

#include <cassert>
#include <cstdlib>
#include <cstring>

#include <algorithm>
#include <array>
#include <chrono>
#include <exception>
#include <iostream>
#include <new>
#include <random>
#include <stdexcept>
#include <string>
#include <thread>
#include <vector>

namespace mps
{
    /// Host name (or IP) and TCP port parsed from a "<host>:<port>" string.
    struct hostport
    {
        std::string host;
        int port;
    };

    /** @class connection
     *
     * Owns a MariaDB client handle (MYSQL*) connected to the database "test".
     * Non-copyable because the destructor closes the handle.
     */
    class connection
    {
    public:
        /**
         * Constructor. Connects immediately.
         *
         * @param addr Address string in form of <host>:<port>
         * @param user User name
         * @param password Password; an empty string means "no password"
         * @throws std::exception subclasses if addr cannot be parsed, and
         *         std::runtime_error if the connection attempt fails
         */
        connection(const std::string& addr,
                   const std::string& user,
                   const std::string& password);
        ~connection();

        connection(const connection&) = delete;
        connection& operator=(const connection&) = delete;

        /// Raw handle for direct C API calls; ownership stays with this object.
        MYSQL* native() { return conn_; }
    private:
        MYSQL* conn_;
    };


    /** @class schema
     *
     * Creates (prepare) and drops (cleanup) the test table t1 through a
     * borrowed connection, which must outlive this object.
     */
    class schema
    {
    public:
        explicit schema(connection&);
        ~schema();

        // Holds only a reference; copying would just alias the same connection.
        schema(const schema&) = delete;
        schema& operator=(const schema&) = delete;

        /// Drop t1 if it exists and create it anew. @throws std::runtime_error
        void prepare();
        /// Drop t1. @throws std::runtime_error
        void cleanup();
    private:
        connection& conn_;
    };

    /** @class stmt
     *
     * Encapsulates a MYSQL_STMT object on a borrowed connection and defines
     * helpers that prepare a DML statement against t1 and execute it once for
     * a whole vector of key values, bound as parameter arrays. The run_*
     * helpers do nothing when given an empty vector.
     */
    class stmt
    {
    public:
        /// @param verbose When true, statements and batch sizes go to stdout.
        explicit stmt(connection&, bool verbose = false);
        ~stmt();

        stmt(const stmt&) = delete;
        stmt& operator=(const stmt&) = delete;

        /// Prepare an SQL statement. @throws std::runtime_error on failure.
        void prepare(const std::string&);
        /// INSERT one row per key; f2 is set to the connection thread id.
        void run_insert(const std::vector<int>&);
        /// REPLACE one row per key; f2 is set to the connection thread id.
        void run_replace(const std::vector<int>&);
        /// UPDATE the row with f1 equal to each key.
        void run_update(const std::vector<int>&);
        /// UPDATE all rows with f1 greater than each key.
        void run_update_range(const std::vector<int>&);
        /// DELETE the row with f1 equal to each key.
        void run_delete(const std::vector<int>&);
    private:
        /// Bind the first vector to the first placeholder and, if non-empty,
        /// the second vector to the second placeholder (one element per row),
        /// then execute the prepared statement.
        void bind_and_execute(const std::vector<int>&, const std::vector<int>& = std::vector<int>());
        connection& conn_;
        MYSQL_STMT* stmt_;
        bool verbose_;
    };

}

/**
 * Split an address of the form <host>:<port> into its parts.
 *
 * The port is taken after the *last* ':' so that IPv6 literals also work,
 * e.g. "::1:3306" or "[::1]:3306"; enclosing square brackets are removed
 * from the host part.
 *
 * @param addr Address string.
 * @return Host and port.
 * @throws std::runtime_error if addr contains no ':' or the port is not a
 *         decimal number in the range 0-65535.
 */
mps::hostport get_host_port(const std::string& addr)
{
    const auto sep = addr.rfind(':');
    if (sep == std::string::npos)
    {
        throw std::runtime_error(std::string("Invalid addr: ") + addr);
    }

    auto host = addr.substr(0, sep);
    if (host.size() >= 2 && host.front() == '[' && host.back() == ']')
    {
        host = host.substr(1, host.size() - 2);
    }

    // std::stoi alone accepts trailing garbage ("3306x" -> 3306) and throws
    // exceptions whose what() carries no context, so validate explicitly.
    const auto port_str = addr.substr(sep + 1);
    std::size_t consumed = 0;
    int port = -1;
    try
    {
        port = std::stoi(port_str, &consumed);
    }
    catch (const std::invalid_argument&)
    {
        // Reported below: port stays -1.
    }
    catch (const std::out_of_range&)
    {
        // Reported below: port stays -1.
    }
    if (consumed != port_str.size() || port < 0 || port > 65535)
    {
        throw std::runtime_error(std::string("Invalid port in addr: ") + addr);
    }
    return mps::hostport{host, port};
}

/**
 * Parse addr, allocate a client handle and connect to database "test".
 *
 * @throws std::bad_alloc if mysql_init() cannot allocate a handle.
 * @throws std::runtime_error if the connection attempt fails; the message
 *         includes the client library's error text.
 */
mps::connection::connection(const std::string& addr,
                            const std::string& user,
                            const std::string& password)
    : conn_(nullptr)
{
    // Parse before allocating: the destructor does not run for an object
    // whose constructor throws, so a handle allocated earlier would leak.
    const auto hp = get_host_port(addr);

    conn_ = mysql_init(nullptr);
    if (conn_ == nullptr)
    {
        throw std::bad_alloc();
    }

    if (mysql_real_connect(
            conn_, hp.host.c_str(), user.c_str(),
            password.empty() ? nullptr : password.c_str(),
            "test", hp.port, nullptr, CLIENT_INTERACTIVE) == nullptr)
    {
        // Copy the error text before mysql_close() releases the handle.
        const std::string reason(mysql_error(conn_));
        mysql_close(conn_);
        throw std::runtime_error("Could not connect to " + addr + ": " + reason);
    }
}

// Closes the client handle that the constructor connected.
mps::connection::~connection()
{
    MYSQL* const handle = conn_;
    mysql_close(handle);
}

mps::schema::schema(mps::connection& conn)
    : conn_(conn)
{ }

// Nothing to release: the connection is borrowed, not owned.
mps::schema::~schema() = default;

void mps::schema::prepare()
{
    if (mysql_query(conn_.native(), "DROP TABLE IF EXISTS t1") != 0)
    {
        throw std::runtime_error(mysql_error(conn_.native()));
    }

    if (mysql_query(
            conn_.native(),
            "CREATE TABLE t1 (f1 INT PRIMARY KEY, f2 INT, op VARCHAR(10))") != 0)
    {
        throw std::runtime_error(mysql_error(conn_.native()));
    }
}

void mps::schema::cleanup()
{
    if (mysql_query(conn_.native(), "DROP TABLE t1") != 0)
    {
        throw std::runtime_error(mysql_error(conn_.native()));
    }
}

mps::stmt::stmt(mps::connection& conn, bool verbose)
    : conn_(conn)
    , stmt_(mysql_stmt_init(conn_.native()))
    , verbose_(verbose)
{
    if (stmt_ == nullptr)
    {
        throw std::bad_alloc();
    }
}

// Closes the statement handle obtained in the constructor; the result of
// mysql_stmt_close() is intentionally ignored in the destructor.
mps::stmt::~stmt()
{
    static_cast<void>(mysql_stmt_close(stmt_));
}

/**
 * Prepare an SQL statement on this handle. It may be called repeatedly on
 * the same handle; the run_* helpers do so.
 *
 * @param str SQL text, possibly containing '?' placeholders.
 * @throws std::runtime_error if preparation fails.
 */
void mps::stmt::prepare(const std::string& str)
{
    if (verbose_) std::cout << "Preparing: " << str << "\n";
    // Pass the exact length instead of -1: -1 becomes ULONG_MAX and only
    // works if the client library treats that value as "use strlen()".
    const auto length = static_cast<unsigned long>(str.size());
    if (mysql_stmt_prepare(stmt_, str.c_str(), length) != 0)
    {
        throw std::runtime_error(
            std::string("Failed to prepare: ") + mysql_stmt_error(stmt_));
    }
}

/**
 * Convert a (possibly const) element pointer into the non-const void* that
 * MYSQL_BIND::buffer is declared with.
 *
 * The buffers passed here are only bound as input parameters, so dropping
 * const is assumed safe. NOTE(review): this relies on the client library not
 * writing through parameter buffers.
 */
template <typename T>
static void* buffer_cast(T* t)
{
    // static_cast suffices for object-pointer -> const void*; const_cast then
    // removes the const qualifier. reinterpret_cast is not needed.
    return const_cast<void*>(static_cast<const void*>(t));
}

/**
 * Bind one or two integer columns as parameter arrays and execute the
 * prepared statement once for all rows (STMT_ATTR_ARRAY_SIZE = row count).
 *
 * @param values_1 Values for the first placeholder, one per row. Must not be
 *                 empty; its size is the number of rows.
 * @param values_2 Values for the second placeholder, one per row, or empty
 *                 when the statement has only one placeholder.
 * @throws std::invalid_argument if values_1 is empty, or values_2 is
 *         non-empty with a size different from values_1.
 * @throws std::runtime_error if binding, setting the array size or resetting
 *         the statement fails. Execution errors are not thrown (see below).
 */
void mps::stmt::bind_and_execute(
    const std::vector<int>& values_1, const std::vector<int>& values_2)
{
    if (verbose_) std::cout << "Bind and execute: v1 size: " << values_1.size()
                            << " v2 size: " << values_2.size() << "\n";
    if (values_1.empty())
    {
        throw std::invalid_argument("bind_and_execute: no rows to bind");
    }
    // Both columns are sent as one batch of values_1.size() rows, so a second
    // column must supply exactly one value per row; a shorter vector would
    // make the client read past its end.
    if (not values_2.empty() && values_2.size() != values_1.size())
    {
        throw std::invalid_argument(
            "bind_and_execute: parameter arrays differ in size");
    }

    // Zero-filled indicator arrays: no row carries a NULL/DEFAULT indicator
    // (assumes STMT_INDICATOR_NONE == 0).
    std::vector<char> f1_ind(values_1.size());
    std::vector<char> f2_ind(values_1.size());

    MYSQL_BIND bind[2];
    std::memset(bind, 0, sizeof(bind));
    bind[0].u.indicator = f1_ind.data();
    bind[0].buffer_type = MYSQL_TYPE_LONG;
    bind[0].buffer = buffer_cast(values_1.data());
    if (not values_2.empty())
    {
        bind[1].u.indicator = f2_ind.data();
        bind[1].buffer_type = MYSQL_TYPE_LONG;
        bind[1].buffer = buffer_cast(values_2.data());
    }

    if (mysql_stmt_bind_param(stmt_, bind) != 0)
    {
        throw std::runtime_error(mysql_stmt_error(stmt_));
    }

    // If the array size is not accepted, only one row would be sent, so a
    // failure here must not be ignored.
    unsigned int numrows = static_cast<unsigned int>(values_1.size());
    if (mysql_stmt_attr_set(stmt_, STMT_ATTR_ARRAY_SIZE, &numrows) != 0)
    {
        throw std::runtime_error(mysql_stmt_error(stmt_));
    }

    if (mysql_stmt_execute(stmt_) != 0)
    {
        // Deliberately non-fatal: the concurrent run uses random, overlapping
        // primary keys from two clients, so errors such as duplicate-key
        // violations are expected there. Report them in verbose mode only.
        if (verbose_)
        {
            std::cerr << "Execute failed: " << mysql_stmt_error(stmt_) << "\n";
        }
    }
    if (mysql_stmt_reset(stmt_) != 0)
    {
        throw std::runtime_error(mysql_stmt_error(stmt_));
    }
}

// Insert one row per key, tagging f2 with this connection's thread id.
// An empty key list is a no-op.
void mps::stmt::run_insert(const std::vector<int>& pk_values)
{
    if (not pk_values.empty())
    {
        prepare("INSERT INTO t1 VALUES (?, ?, 'insert')");
        const std::vector<int> f2_values(
            pk_values.size(), mysql_thread_id(conn_.native()));
        bind_and_execute(pk_values, f2_values);
    }
}

// REPLACE one row per key, tagging f2 with this connection's thread id.
// An empty key list is a no-op.
void mps::stmt::run_replace(const std::vector<int>& pk_values)
{
    if (!pk_values.empty())
    {
        prepare("REPLACE INTO t1 VALUES (?, ?, 'replace')");
        const std::vector<int> thread_ids(
            pk_values.size(), mysql_thread_id(conn_.native()));
        bind_and_execute(pk_values, thread_ids);
    }
}

// For each key, set f2 to this connection's thread id on the row whose f1
// equals the key. An empty key list is a no-op.
void mps::stmt::run_update(const std::vector<int>& pk_values)
{
    if (pk_values.empty())
    {
        return;
    }
    prepare("UPDATE t1 SET f2 = ?, op = 'update' WHERE f1 = ?");
    // Placeholder order in the SQL: the SET value first, then the key.
    const std::vector<int> new_f2(pk_values.size(),
                                  mysql_thread_id(conn_.native()));
    bind_and_execute(new_f2, pk_values);
}

// For each key, set f2 to this connection's thread id on every row whose f1
// is greater than the key. An empty key list is a no-op.
void mps::stmt::run_update_range(const std::vector<int>& pk_values)
{
    if (pk_values.empty())
    {
        return;
    }
    prepare("UPDATE t1 SET f2 = ?, op = 'update' WHERE f1 > ?");
    // Placeholder order in the SQL: the SET value first, then the lower bound.
    const std::vector<int> new_f2(pk_values.size(),
                                  mysql_thread_id(conn_.native()));
    bind_and_execute(new_f2, pk_values);
}

// Delete the row whose f1 equals each key. An empty key list is a no-op.
void mps::stmt::run_delete(const std::vector<int>& pk_values)
{
    if (!pk_values.empty())
    {
        prepare("DELETE FROM t1 WHERE f1 = ?");
        bind_and_execute(pk_values);  // single placeholder: no second column
    }
}

//
// Config parameters, filled in by parse_args().
//
static std::string addr1;                             // first <host>:<port>
static std::string addr2;                             // second <host>:<port>
static std::chrono::seconds concurrent_duration{10};  // 0 selects the simple test

static void prepare()
{
    mps::connection conn(addr1, "root", "");
    mps::schema schema(conn);
    schema.prepare();
}

static void cleanup()
{
    mps::connection conn(addr1, "root", "");
    mps::schema schema(conn);
    schema.cleanup();
}

// Single client bulk regular and bulk statements.
static void run_simple()
{
    mps::connection conn(addr1, "root", "");
    mps::stmt stmt(conn, true);

    stmt.run_insert(std::vector<int>{4});
    stmt.run_insert(std::vector<int>{1, 2, 3});
    stmt.run_replace(std::vector<int>{1});
    stmt.run_replace(std::vector<int>{2, 3});

    stmt.run_update(std::vector<int>{2});
    stmt.run_update(std::vector<int>{3, 4});
    stmt.run_delete(std::vector<int>{4});
    stmt.run_delete(std::vector<int>{3, 2});
}

/**
 * Return a pseudo-random key in the range [0, 9].
 *
 * Both worker threads call this concurrently, and std::rand() is not
 * guaranteed to be thread-safe, so each thread owns its own engine. Each
 * engine is seeded from std::random_device so the threads do not produce
 * identical sequences.
 */
int gen_rand()
{
    thread_local std::minstd_rand engine{std::random_device{}()};
    std::uniform_int_distribution<int> dist(0, 9);
    return dist(engine);
}

static void run_one_thread(const std::string& addr)
{
    auto run_until{std::chrono::steady_clock::now() + concurrent_duration};
    mps::connection conn(addr, "root", "");

    while (run_until > std::chrono::steady_clock::now())
    {
        mps::stmt stmt(conn);
        stmt.run_insert(std::vector<int>{gen_rand(), gen_rand()});
        stmt.run_replace(std::vector<int>{gen_rand(), gen_rand()});
        stmt.run_update(std::vector<int>{gen_rand()});
        stmt.run_update(std::vector<int>{gen_rand(), gen_rand()});
        stmt.run_update_range(std::vector<int>{gen_rand(), gen_rand()});
        stmt.run_delete(std::vector<int>{gen_rand()});
        stmt.run_delete(std::vector<int>{gen_rand(), gen_rand()});
    }
}

static void run_concurrent()
{
    std::thread t1{run_one_thread, addr1};
    std::thread t2{run_one_thread, addr2};

    t1.join();
    t2.join();
}

/**
 * Print command-line usage to stderr.
 *
 * @param prog_name Program name shown in the usage line (normally argv[0]).
 *        Taken by const reference to avoid copying the string.
 */
static void print_help(const std::string& prog_name)
{
    std::cerr << "Usage: " << prog_name << " <runtime> <addr1> <addr2>\n";
    std::cerr << "\n\t<addr1> <addr2> are in form of <ip>:<port>.\n";
    std::cerr << "\n\tIf <runtime> equals zero, simple testcase is run,\n"
              << "\totherwise concurrent\n\n";
}

/**
 * Parse the command line into addr1, addr2 and concurrent_duration.
 *
 * "help" as the only argument prints usage and exits with status 0. A wrong
 * argument count prints usage and exits with status 1.
 *
 * @throws std::runtime_error if <runtime> is not a non-negative integer.
 */
static void parse_args(int argc, char** argv)
{
    // argv[0] may be null when argc == 0; fall back to a fixed name.
    const std::string prog_name =
        (argc > 0 && argv[0] != nullptr) ? argv[0] : "mariadb_ps";

    if (argc == 2 && std::string(argv[1]) == "help")
    {
        print_help(prog_name);
        std::exit(0);
    }
    if (argc != 4)
    {
        print_help(prog_name);
        std::exit(1);
    }

    // Validate the runtime explicitly. std::stoi alone accepts trailing
    // garbage such as "10s", and a negative value would make main() run
    // neither the simple nor the concurrent test.
    const std::string runtime(argv[1]);
    std::size_t consumed = 0;
    int seconds = -1;
    try
    {
        seconds = std::stoi(runtime, &consumed);
    }
    catch (const std::invalid_argument&)
    {
        // Reported below: seconds stays -1.
    }
    catch (const std::out_of_range&)
    {
        // Reported below: seconds stays -1.
    }
    if (consumed != runtime.size() || seconds < 0)
    {
        throw std::runtime_error("Invalid runtime: " + runtime);
    }

    concurrent_duration = std::chrono::seconds(seconds);
    addr1 = argv[2];
    addr2 = argv[3];
}

// Entry point: create the table, run the simple test (runtime 0) or the
// concurrent test (runtime > 0), then drop the table. Any std::exception is
// reported on stderr and turned into exit status 1.
int main(int argc, char** argv)
{
    try
    {
        parse_args(argc, argv);
        prepare();

        const auto zero = std::chrono::seconds(0);
        if (concurrent_duration == zero)
            run_simple();
        else if (concurrent_duration > zero)
            run_concurrent();

        cleanup();
        return 0;
    }
    catch (const std::exception& e)
    {
        std::cerr << "Error: " << e.what() << "\n";
        return 1;
    }
}
