package main

import (
	"database/sql"
	"fmt"
	"log"
	"strings"

	_ "github.com/go-sql-driver/mysql"
	"github.com/ory/dockertest"
)

func spinDocker(image, version string) (string, func(), error) {
	pool, err := dockertest.NewPool("")
	if err != nil {
		return "", nil, err
	}
	resource, err := pool.Run(image, version, []string{
		fmt.Sprintf("MYSQL_ROOT_PASSWORD=secret"),
		fmt.Sprintf("MYSQL_DATABASE=testing"),
	})
	if err != nil {
		return "", nil, err
	}
	dsn := fmt.Sprintf(
		"root:secret@(%s:3306)/testing",
		resource.Container.NetworkSettings.IPAddress,
	)
	if err = pool.Retry(func() error {
		db, err := sql.Open("mysql", dsn)
		if err != nil {
			return err
		}
		defer db.Close()
		return db.Ping()
	}); err != nil {
		pool.Purge(resource)
		return "", nil, err
	}

	return dsn, func() {
		pool.Purge(resource)
	}, nil
}

// connectAndPrepare opens a "mysql" connection pool for dsn, creates table x
// with a single VARBINARY(16) primary-key column b, and inserts blob into it.
//
// The statements are run with Exec rather than Query: Query returns *sql.Rows
// that hold a pooled connection until closed, which would leak connections
// for statements that produce no result set.
//
// On success the caller owns the returned *sql.DB and must Close it. On
// failure the pool is closed here and a nil *sql.DB is returned.
func connectAndPrepare(dsn string, blob []byte) (*sql.DB, error) {
	db, err := sql.Open("mysql", dsn)
	if err != nil {
		return nil, err
	}
	if _, err := db.Exec("CREATE TABLE x (b VARBINARY(16) NOT NULL, PRIMARY KEY (b))"); err != nil {
		db.Close()
		return nil, fmt.Errorf("creating table x: %w", err)
	}
	if _, err := db.Exec("INSERT INTO x VALUES(?)", blob); err != nil {
		db.Close()
		return nil, fmt.Errorf("inserting blob: %w", err)
	}
	return db, nil
}

// readBlob runs "SELECT b FROM x WHERE b IN (?,?,...)" with count placeholders,
// each bound to blob, and returns the b value of every matching row.
//
// count must be positive. An empty "IN ()" list is not valid SQL, so it is
// rejected here instead of being sent to the server. On any error the
// returned slice is nil.
func readBlob(db *sql.DB, blob []byte, count int) ([][]byte, error) {
	if count <= 0 {
		return nil, fmt.Errorf("readBlob: count must be positive, got %d", count)
	}

	args := make([]interface{}, count)
	placeholders := make([]string, count)
	for i := range args {
		args[i] = blob
		placeholders[i] = "?"
	}
	query := "SELECT b FROM x WHERE b IN (" + strings.Join(placeholders, ",") + ")"

	rows, err := db.Query(query, args...)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var found [][]byte
	for rows.Next() {
		var b []byte
		if err := rows.Scan(&b); err != nil {
			return nil, err
		}
		found = append(found, b)
	}
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return found, nil
}

// main runs testIt against each image/version listed below. Every run uses
// the single-byte blob 0xff and IN lists of 999 and 1000 placeholders.
// NOTE(review): 999/1000 presumably straddle a server-side IN-list size
// threshold — confirm.
func main() {
	targets := []struct {
		image   string
		version string
	}{
		{"mariadb", "10.2"},
		{"mariadb", "10.3.2"},
		{"mariadb", "10.3.3"},
		{"mariadb", "10.3"},
		{"mariadb", "10.4"},
		{"mysql", "5.7"},
		{"mysql", "8.0"},
	}

	probe := []byte{0xff}
	sizes := []int{999, 1000}

	for _, tgt := range targets {
		testIt(tgt.image, tgt.version, sizes, probe)
	}
}

// testIt starts image:version in Docker and loads blob into a fresh table.
// For each entry in counts it looks blob up with an IN list of that many
// copies, printing SUCCESS when at least one row matches and FAILURE
// otherwise.
//
// Any error is fatal. The work happens in runBlobTest so that its deferred
// container cleanup and db.Close run before log.Fatal, which exits via
// os.Exit and would otherwise skip them and leak the container.
func testIt(image string, version string, counts []int, blob []byte) {
	if err := runBlobTest(image, version, counts, blob); err != nil {
		log.Fatal(err)
	}
}

// runBlobTest performs testIt's work and returns the first error it hits.
// Its deferred cleanups have already run when it returns.
func runBlobTest(image, version string, counts []int, blob []byte) error {
	dsn, cleanup, err := spinDocker(image, version)
	if err != nil {
		return fmt.Errorf("%s %s: starting container: %w", image, version, err)
	}
	// Deferred only after the error check: cleanup is nil when spinDocker fails.
	defer cleanup()

	db, err := connectAndPrepare(dsn, blob)
	if err != nil {
		return fmt.Errorf("%s %s: preparing database: %w", image, version, err)
	}
	defer db.Close()

	for _, count := range counts {
		found, err := readBlob(db, blob, count)
		if err != nil {
			return fmt.Errorf("%s %s: reading blob with count %d: %w", image, version, count, err)
		}
		result := "FAILURE"
		if len(found) > 0 {
			result = "SUCCESS"
		}
		fmt.Println(image, version, "count", count, "blob", blob, "->", found, result)
	}
	return nil
}
