#include <mysql.h>
#include <stdio.h>
#include <pthread.h>
#include <string.h>
#include <stdlib.h>

/* Result codes for test functions (test_ssl_threads() returns OK/FAIL). */
#define OK 0
#define FAIL 1
#define SKIP 2


/*
 * Prints the client error of `mysql` when `rc` is non-zero.
 * It neither returns nor aborts: the caller decides how to handle the
 * failure. The do/while(0) wrapper makes the expansion a single statement,
 * so it is safe inside an unbraced if/else.
 */
#define check_mysql_rc(rc, mysql)                    \
do {                                                 \
  if (rc)                                            \
    printf("error: %s\n", mysql_error((mysql)));     \
} while (0)
/* Account each ssl_thread() connects with. */
const char *ssluser= "ssluser";
const char *sslpw= "sslpw";
/* NOTE(review): not referenced anywhere in this file; confirm before removing. */
char sslhost[128];

/*
 * Serializes the "UPDATE ssltest" issued by every ssl_thread().
 * Initialized and destroyed by test_ssl_threads().
 */
pthread_mutex_t LOCK_test;

/*
 * Connection parameters passed unchanged to mysql_real_connect():
 * main() connects as username/password, ssl_thread() as ssluser/sslpw,
 * both to `schema` on `hostname`. port 0 and a NULL socket presumably
 * select the client library defaults -- confirm against the connector docs.
 */
const char *schema= "testc";
const char *hostname= "127.0.0.1";
const char *username= "root";
const char *password= NULL;
int port= 0;
const char *socketname= NULL;


/*
 * Worker started by test_ssl_threads().
 *
 * Opens a private connection as ssluser with ./certs/ca-cert.pem configured
 * as CA via mysql_ssl_set(), increments ssltest.a once while holding
 * LOCK_test, then closes the connection. Any error is printed and the
 * thread finishes without touching the counter.
 *
 * The POSIX variant returns void * so its type matches the start routine
 * pthread_create() calls through; calling a `void f(void *)` through a
 * `void *(*)(void *)` pointer is undefined behavior. The WIN32 variant
 * returns a DWORD exit code as CreateThread() expects.
 */
#ifndef WIN32
static void *ssl_thread(void *dummy)
#else
DWORD WINAPI ssl_thread(void *dummy)
#endif
{
  MYSQL *mysql;

  (void)dummy;
  mysql_thread_init();

  if ((mysql= mysql_init(NULL)) != NULL)
  {
    mysql_ssl_set(mysql, 0, 0, "./certs/ca-cert.pem", 0, 0);

    if (!mysql_real_connect(mysql, hostname, ssluser, sslpw, schema,
                            port, socketname, 0))
    {
      printf(">Error: %s\n", mysql_error(mysql));
    }
    else
    {
      /* Serialize the increments across all worker threads. */
      pthread_mutex_lock(&LOCK_test);
      if (mysql_query(mysql, "UPDATE ssltest SET a=a+1"))
        printf(">Error: %s\n", mysql_error(mysql));
      pthread_mutex_unlock(&LOCK_test);
    }
    mysql_close(mysql);
  }

  mysql_thread_end();
  /* Returning from the start routine is equivalent to pthread_exit(). */
#ifndef WIN32
  return NULL;
#else
  return 0;
#endif
}

/*
 * Multi-threaded SSL smoke test.
 *
 * Recreates table ssltest with one row (a = 0), starts NUM_SSL_THREADS
 * threads running ssl_thread() (each opens its own SSL connection and
 * increments `a` once under LOCK_test), joins them, and checks that the
 * final value equals NUM_SSL_THREADS.
 *
 * mysql: an already connected handle used for setup and verification.
 * Returns OK when the counter matches, FAIL on any setup, thread-start or
 * verification error.
 */
static int test_ssl_threads(MYSQL *mysql)
{
  enum { NUM_SSL_THREADS= 50 };
  int i, rc, started= 0, result= OK;
  long found;
  char *end;
#ifndef WIN32
  pthread_t threads[NUM_SSL_THREADS];
#else
  HANDLE hthreads[NUM_SSL_THREADS];
  DWORD dthreads[NUM_SSL_THREADS];
#endif
  MYSQL_RES *res;
  MYSQL_ROW row;

  /* Setup failures make the thread phase meaningless, so stop early. */
  rc= mysql_query(mysql, "DROP TABLE IF exists ssltest");
  check_mysql_rc(rc, mysql);
  if (rc)
    return FAIL;
  rc= mysql_query(mysql, "CREATE TABLE ssltest (a int)");
  check_mysql_rc(rc, mysql);
  if (rc)
    return FAIL;
  rc= mysql_query(mysql, "INSERT into ssltest VALUES (0)");
  check_mysql_rc(rc, mysql);
  if (rc)
    return FAIL;

  /* Initialize exactly once: re-initializing a mutex is undefined. */
  if (pthread_mutex_init(&LOCK_test, NULL))
  {
    printf("error while initializing mutex\n");
    return FAIL;
  }

  for (i= 0; i < NUM_SSL_THREADS; i++)
  {
#ifndef WIN32
    if (pthread_create(&threads[i], NULL, (void *)ssl_thread, NULL))
#else
    hthreads[i]= CreateThread(NULL, 0, (LPTHREAD_START_ROUTINE)ssl_thread, NULL, 0, &dthreads[i]);
    if (hthreads[i] == NULL)
#endif
    {
      /* Stop here: only the handles created so far are valid to join. */
      printf("error while starting thread\n");
      result= FAIL;
      break;
    }
    started++;
  }

  for (i= 0; i < started; i++)
  {
#ifndef WIN32
    pthread_join(threads[i], NULL);
#else
    WaitForSingleObject(hthreads[i], INFINITE);
    CloseHandle(hthreads[i]);
#endif
  }

  pthread_mutex_destroy(&LOCK_test);

  rc= mysql_query(mysql, "SELECT a FROM ssltest");
  check_mysql_rc(rc, mysql);
  if (rc)
    return FAIL;
  if (!(res= mysql_store_result(mysql)))
  {
    printf("error: %s\n", mysql_error(mysql));
    return FAIL;
  }
  row= mysql_fetch_row(res);
  if (!row || !row[0])
  {
    printf("error: no value found in ssltest\n");
    mysql_free_result(res);
    return FAIL;
  }
  printf("Found: %s\n", row[0]);
  found= strtol(row[0], &end, 10);
  if (end == row[0] || *end != '\0' || found != NUM_SSL_THREADS)
  {
    printf("Expected %d\n", NUM_SSL_THREADS);
    result= FAIL;
  }
  mysql_free_result(res);
  return result;
}

int main()
{
  MYSQL *mysql;

  mysql_library_init(0,NULL,0);
  
  mysql= mysql_init(NULL);
  if(!mysql_real_connect(mysql, hostname, username, password, schema,
                         port, socketname, 0))
  {
    printf("Error: %s\n", mysql_error(mysql));
    exit(-1);
  }
  test_ssl_threads(mysql);

  mysql_library_end();

}
