/**
 * @file test_stmt_large_param_num.cpp
 * @brief Issue a simple prepared statement with the supplied number of parameters.
 * @date 2022-03-11
 */

#include <cctype>
#include <cerrno>
#include <cstdint>
#include <cstdlib>
#include <cstring>
#include <iostream>
#include <memory>
#include <string>
#include <thread>
#include <vector>

#include <stdio.h>
#include <time.h>
#include <unistd.h>

#include <mysql.h>
#include <getopt.h>

/**
 * @brief Runs 'stmt_text' through 'mysql_query()' on 'conn'. On failure it prints the error,
 *   together with the file and line of the call site, to stderr and makes the *enclosing
 *   function* return 'EXIT_FAILURE'. The connection is not closed.
 */
#define MYSQL_QUERY(conn, stmt_text) \
	do { \
		const bool query_failed = mysql_query((conn), (stmt_text)) != 0; \
		if (query_failed) { \
			fprintf(stderr, "File %s, line %d, Error: %s\n", __FILE__, __LINE__, mysql_error((conn))); \
			return EXIT_FAILURE; \
		} \
	} while (0)

/// Size of the buffer used to fetch the 'c2' column of 'test.stmt_params' (declared 'varchar(32)').
constexpr uint32_t STRING_SIZE { 32 };
using std::string;

/**
 * @brief Prepares and executes 'query', binding parameter 'i' to the BIGINT value 'i', and checks
 *   that the first returned row is (id:1, c1:100, c2:'abcde').
 *
 * @details The result set is fetched as three columns: (INT, BIGINT, STRING). The connection is
 *   owned by the caller and is left open on every path.
 *
 * @param mysql An already connected handle.
 * @param query The query to prepare; it must contain exactly 'n_params' placeholders.
 * @param n_params Number of parameters to bind.
 * @return EXIT_SUCCESS if the fetched row matches the expected values, EXIT_FAILURE otherwise.
 */
int perform_stmt_select(MYSQL* mysql, const string& query, uint32_t n_params) {
	// Owns the statement handle: it's closed on every return path.
	std::unique_ptr<MYSQL_STMT, decltype(&mysql_stmt_close)> stmt { mysql_stmt_init(mysql), &mysql_stmt_close };
	if (!stmt) {
		fprintf(stderr, "Line %d, 'mysql_stmt_init()' failed, out of memory\n", __LINE__);
		return EXIT_FAILURE;
	}

	if (mysql_stmt_prepare(stmt.get(), query.c_str(), query.size())) {
		fprintf(stderr, "Line %d, 'mysql_stmt_prepare()' failed: '%s'\n", __LINE__, mysql_stmt_error(stmt.get()));
		return EXIT_FAILURE;
	}

	// 'mysql_stmt_bind_param()' reads one bind per placeholder of the prepared statement, so a
	// mismatch with 'n_params' would make it read past the end of 'bind_params'.
	const unsigned long param_count = mysql_stmt_param_count(stmt.get());
	if (param_count != n_params) {
		fprintf(stderr, "Line %d, Placeholder count mismatch - Exp=%u, Act=%lu\n",
			__LINE__, static_cast<unsigned>(n_params), param_count);
		return EXIT_FAILURE;
	}

	// Parameter 'i' holds the value 'i'. Value-initialization zeroes every 'MYSQL_BIND', so only
	// the relevant fields need to be set.
	std::vector<int64_t> data_param(n_params);
	std::vector<MYSQL_BIND> bind_params(n_params);

	for (uint32_t i = 0; i < n_params; i++) {
		data_param[i] = i;

		bind_params[i].buffer_type = MYSQL_TYPE_LONGLONG;
		bind_params[i].buffer = &data_param[i];
		bind_params[i].buffer_length = sizeof(int64_t);
	}

	if (mysql_stmt_bind_param(stmt.get(), bind_params.data())) {
		fprintf(stderr, "Line %d, 'mysql_stmt_bind_param()' failed: '%s'\n", __LINE__, mysql_stmt_error(stmt.get()));
		return EXIT_FAILURE;
	}

	if (mysql_stmt_execute(stmt.get())) {
		fprintf(stderr, "Line %d, 'mysql_stmt_execute()' failed: '%s'\n", __LINE__, mysql_stmt_error(stmt.get()));
		return EXIT_FAILURE;
	}

	MYSQL_BIND bind[3] {};
	int data_id = 0;
	int64_t data_c1 = 0;
	// One extra zero byte, never handed to the library, guarantees NUL-termination even when the
	// fetched value fills the whole 'buffer_length'.
	char data_c2[STRING_SIZE + 1] {};
#ifdef LIBMYSQLCLIENT
	bool is_null[3] {};
	bool error[3] {};
#else
	char is_null[3] {};
	char error[3] {};
#endif
	unsigned long length[3] {};

	bind[0].buffer_type = MYSQL_TYPE_LONG;
	bind[0].buffer = &data_id;
	bind[0].buffer_length = sizeof(int);
	bind[0].is_null = &is_null[0];
	bind[0].length = &length[0];

	bind[1].buffer_type = MYSQL_TYPE_LONGLONG;
	bind[1].buffer = &data_c1;
	bind[1].buffer_length = sizeof(int64_t);
	bind[1].is_null = &is_null[1];
	bind[1].length = &length[1];

	bind[2].buffer_type = MYSQL_TYPE_STRING;
	bind[2].buffer = data_c2;
	bind[2].buffer_length = STRING_SIZE;
	bind[2].is_null = &is_null[2];
	bind[2].length = &length[2];
	bind[2].error = &error[2];

	if (mysql_stmt_bind_result(stmt.get(), bind)) {
		fprintf(stderr, "Line %d, 'mysql_stmt_bind_result()' failed: '%s'\n", __LINE__, mysql_stmt_error(stmt.get()));
		return EXIT_FAILURE;
	}

	// Besides errors (1), 'MYSQL_NO_DATA' (no rows) and 'MYSQL_DATA_TRUNCATED' are failures too:
	// in neither case do the result buffers hold the complete expected row.
	const int fetch_rc = mysql_stmt_fetch(stmt.get());
	if (fetch_rc != 0) {
		fprintf(stderr, "Line %d, 'mysql_stmt_fetch()' failed with code %d: '%s'\n",
			__LINE__, fetch_rc, mysql_stmt_error(stmt.get()));
		return EXIT_FAILURE;
	}

	const bool data_match_expected =
		(data_id == 1) &&
		(data_c1 == static_cast<int64_t>(100)) &&
		(strcmp(data_c2, "abcde") == 0);

	if (!data_match_expected) {
		// '%lld' + cast: 'int64_t' is not 'long' on every platform.
		fprintf(stderr,
			"Prepared statement SELECT result didn't matched expected -"
			" Exp=(id:1, c1:100, c2:'abcde'), Act=(id:%d, c1:%lld, c2:'%s')\n",
			data_id, static_cast<long long>(data_c1), data_c2
		);
		return EXIT_FAILURE;
	}

	return EXIT_SUCCESS;
}

/**
 * @brief Builds "SELECT * FROM <table_name> WHERE id IN (?,?,...)" with 'n_params' comma-separated
 *   placeholders. With 'n_params == 0' the list is empty ("IN ()").
 *
 * @param table_name Table to select from (inserted verbatim).
 * @param n_params Number of '?' placeholders in the 'IN' list.
 * @return The query text.
 */
string build_select_query(const string& table_name, const uint32_t n_params) {
	string placeholders {};

	for (uint32_t idx = 0; idx < n_params; ++idx) {
		// Separator goes before every placeholder but the first.
		if (idx > 0) {
			placeholders += ',';
		}
		placeholders += '?';
	}

	return "SELECT * FROM " + table_name + " WHERE id IN (" + placeholders + ")";
}

/**
 * @brief Options parsed from the command line.
 * @note A value-initialized instance (empty 'host') is what 'command_line_options' returns on
 *   failure, and 'main' treats an empty 'host' as an error.
 */
struct conn_opts_t {
	std::string host;      // MySQL host ('-h').
	std::string user;      // MySQL user ('-u').
	std::string pass;      // MySQL password ('-p').
	uint32_t port;         // MySQL port ('-P').
	uint32_t s_param_num;  // Number of prepared statement parameters ('-N').
};

/**
 * @brief Prints the command line synopsis of the test to stdout.
 * @param name Program name to show (normally 'argv[0]').
 */
void usage(const char* name) {
	const char* const options_synopsis =
		" -h <mysql host> -u <mysql user> -p <mysql password> -P <mysql port> -N <stmt_param_num>\n";

	std::cout << "Usage: " << name << options_synopsis;
}

/**
 * @brief Parses an unsigned decimal integer that must fit in 'uint32_t'.
 * @param str Text to parse; may be NULL.
 * @param out Receives the value on success; left untouched on failure.
 * @return 'true' on success; 'false' for NULL, empty, signed, non-numeric, partially numeric or
 *   out-of-range input.
 */
static bool parse_uint32_arg(const char* str, uint32_t& out) {
	// Requiring a leading digit rejects whitespace and signs, which 'strtoull' would otherwise
	// accept (a leading '-' would silently wrap around).
	if (str == nullptr || !std::isdigit(static_cast<unsigned char>(str[0]))) {
		return false;
	}

	errno = 0;
	char* end = nullptr;
	const unsigned long long val = std::strtoull(str, &end, 10);

	if (errno != 0 || *end != '\0' || val > UINT32_MAX) {
		return false;
	}

	out = static_cast<uint32_t>(val);
	return true;
}

/**
 * @brief Parses the command line options of the test.
 *
 * @details Recognized options: '-h' host (mandatory), '-u' user, '-p' password, '-P' port and
 *   '-N' number of statement parameters (defaults to 2000). '-P' and '-N' must be unsigned
 *   decimal integers fitting in 32 bits. Once copied, the password is overwritten with 'x'
 *   characters in 'argv'. Usage is printed when no options are given, when an option is invalid
 *   or when '-h' is missing.
 *
 * @return The parsed options, or a value-initialized 'conn_opts_t' (empty 'host') on failure.
 */
conn_opts_t command_line_options(int argc, char** argv) {
	if (argc == 1) {
		usage(argv[0]);
		return {};
	}

	string host {};
	string user {};
	string pass {};
	uint32_t port = 0;
	uint32_t s_param_num = 2000;

	bool inv_arg = false;
	int c = 0;

	while (-1 != (c = ::getopt(argc, argv, "h:u:p:P:N:"))) {
		switch (c) {
			case 'h':
				if (optarg == NULL) { inv_arg = true; }
				else { host = optarg; }
				break;
			case 'u':
				if (optarg == NULL) { inv_arg = true; }
				else { user = optarg; }
				break;
			case 'p':
				if (optarg == NULL) { inv_arg = true; }
				else {
					pass = optarg;
					// Overwrite the password in 'argv', presumably to keep it out of process listings.
					memset(optarg, 'x', strlen(optarg));
				}
				break;
			case 'P':
				// Unlike 'std::stoi', invalid input is reported instead of throwing.
				if (!parse_uint32_arg(optarg, port)) {
					fprintf(stderr, "Invalid port '%s'\n", optarg ? optarg : "");
					inv_arg = true;
				}
				break;
			case 'N':
				if (!parse_uint32_arg(optarg, s_param_num)) {
					fprintf(stderr, "Invalid stmt param number '%s'\n", optarg ? optarg : "");
					inv_arg = true;
				}
				break;
			default:
				inv_arg = true;
		}

		if (inv_arg) {
			usage(argv[0]);
			break;
		}
	}

	// An empty host is the failure signal for the caller, so report the missing option here.
	if (!inv_arg && host.empty()) {
		fprintf(stderr, "Missing mandatory option '-h'\n");
		usage(argv[0]);
		inv_arg = true;
	}

	if (inv_arg) {
		return {};
	}

	return conn_opts_t { host, user, pass, port, s_param_num };
}

int main(int argc, char** argv) {
	MYSQL* mysql = mysql_init(NULL);
	if (!mysql) {
		fprintf(stderr, "File %s, line %d, Error: %s\n", __FILE__, __LINE__, mysql_error(mysql));
		return EXIT_FAILURE;
	}

	conn_opts_t co { command_line_options(argc, argv) };
	if (co.host.empty()) { return EXIT_FAILURE; }

#ifdef LIBMYSQLCLIENT
	enum mysql_ssl_mode ssl_mode = SSL_MODE_DISABLED;
	mysql_options(mysql, MYSQL_OPT_SSL_MODE, &ssl_mode);
#endif

	const char* c_host = co.host.c_str(); 
	const char* c_user = co.user.c_str(); 
	const char* c_pass = co.pass.c_str(); 

	if (!mysql_real_connect(mysql, c_host, c_user, c_pass, NULL, co.port, NULL, 0)) {
		fprintf(stderr, "File %s, line %d, Error: %s\n", __FILE__, __LINE__, mysql_error(mysql));
		return EXIT_FAILURE;
	}

	// *************************************************************************
	//                 Insert data in the table to be queried
	// *************************************************************************

	MYSQL_QUERY(mysql, "CREATE DATABASE IF NOT EXISTS test");
	MYSQL_QUERY(mysql, "DROP TABLE IF EXISTS test.stmt_params");
	MYSQL_QUERY(mysql,
		"CREATE TABLE IF NOT EXISTS test.stmt_params"
		" (id INT NOT NULL AUTO_INCREMENT PRIMARY KEY, `c1` BIGINT, `c2` varchar(32))"
	);
	MYSQL_QUERY(mysql, "INSERT INTO test.stmt_params (c1, c2) VALUES (100, 'abcde')");

	// *************************************************************************

	// *************************************************************************
	//                      Perform the offending STMT 
	// *************************************************************************

	string query_1 { build_select_query("test.stmt_params", co.s_param_num) };
	int query_res = perform_stmt_select(mysql, query_1, co.s_param_num);

	// *************************************************************************

	return query_res;
}
