package bugtest;

import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.sql.Statement;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

import javax.sql.DataSource;

import com.google.common.collect.Lists;

import org.mariadb.jdbc.MariaDbDataSource;

/**
 * This class can be used to demonstrate an apparent bug in many MariaDB versions (tested on
 * 10.3.34, 10.5.15, 10.6.7, 10.7.3, and 10.8.2).
 *
 * <p>The bug shows up when performing a query to retrieve records from a table matching an IN
 * condition. The values list in the condition is a collection of placeholders, so a prepared
 * statement is used to execute the query.</p>
 *
 * <p>If the {@code useServerPrepStmts} connection property is enabled, and the number of
 * placeholders in the IN list (the batch size) is {@code >= 1000}, these queries return no
 * records.</p>
 *
 * <p>To demonstrate the bug, use arguments that include:</p>
 * <ul>
 *     <li>{@code seedCount >= 1000}</li>
 *     <li>{@code batchSize >= 1000}</li>
 *     <li>{@code addUseServerPrepStmts = true}</li>
 * </ul>
 */
public class BugDemo {

    /** Number of positional command-line arguments that {@link #main} requires. */
    private static final int REQUIRED_ARG_COUNT = 9;

    /**
     * Maximum number of rows written by a single seeding INSERT, so that a large
     * {@code seedCount} does not build one statement bigger than the server will accept.
     */
    private static final int INSERT_CHUNK_SIZE = 10_000;

    private static final String[] USAGE_TEXT = new String[]{
            "Usage: BugDemo batchSize seedCount addUseServerPrepStmts server port user password dbName tblName",
            "where",
            "  batchSize = size of batches to retrieve records",
            "  seedCount = # of records to seed into the table",
            "  addUseServerPrepStmts = true/false - whether to set useServerPrepStmts=true",
            "  server = name/IP address of DB server",
            "  port = port number for DB server",
            "  user = user name for DB login",
            "  password = password for DB login",
            "  dbName = name of database where data will be placed (must already exist)",
            "  tblName = name of table to hold data (created if needed, else truncated for test)"
    };

    /**
     * Parses the command-line arguments and runs the demo.
     *
     * <p>A single {@code -h} argument prints the usage text and returns. If fewer than the
     * required nine arguments are supplied, or if anything fails while parsing or running,
     * an error line followed by the usage text is written to stderr and the JVM exits with
     * status 1. Arguments beyond the required nine are ignored.</p>
     *
     * @param args positional arguments in the order listed in the usage text
     * @throws SQLException declared for signature compatibility; failures are caught,
     *     reported and turned into exit status 1
     */
    public static void main(String[] args) throws SQLException {
        if (args.length == 1 && args[0].equals("-h")) {
            usage();
            return;
        }
        if (args.length < REQUIRED_ARG_COUNT) {
            // Report the real problem instead of an ArrayIndexOutOfBoundsException message.
            System.err.printf("BugDemo requires %d arguments but %d were given%n",
                    REQUIRED_ARG_COUNT, args.length);
            usage();
            System.exit(1);
            return;
        }
        try {
            int i = 0;
            int batchSize = Integer.parseInt(args[i++]);
            int seedCount = Integer.parseInt(args[i++]);
            boolean addUseServerPrepStmts = Boolean.parseBoolean(args[i++]);
            String server = args[i++];
            int port = Integer.parseInt(args[i++]);
            String user = args[i++];
            String password = args[i++];
            String dbName = args[i++];
            String tblName = args[i++];
            new BugDemo().runDemo(batchSize, seedCount, addUseServerPrepStmts, server, port, user,
                    password, dbName, tblName);
        } catch (Exception e) {
            // Print the exception itself (type + message): getMessage() alone may be null.
            // The trailing %n keeps the usage text from being glued onto this line.
            System.err.printf("BugDemo failed to complete: %s%n", e);
            usage();
            System.exit(1);
        }
    }

    /** Writes the usage text, one entry per line, to stderr. */
    private static void usage() {
        System.err.println(String.join("\n", USAGE_TEXT));
    }

    /**
     * Seeds the table, reads every id back, then re-fetches the rows in IN-list batches and
     * reports whether every available record was retrieved.
     *
     * @param batchSize number of ids per IN-list query (must be positive)
     * @param seedCount number of rows to insert; non-positive inserts nothing
     * @param addUseServerPrepStmts whether to enable {@code useServerPrepStmts} on the data source
     * @param server DB server host name or IP address
     * @param port DB server port
     * @param user DB login user
     * @param password DB login password
     * @param dbName existing database to use
     * @param tblName table to create/truncate and query
     * @throws SQLException if any database operation fails
     */
    private void runDemo(int batchSize, int seedCount, boolean addUseServerPrepStmts,
            String server, int port, String user, String password, String dbName, String tblName)
            throws SQLException {
        DataSource ds = getDataSource(server, port, user, password, dbName, addUseServerPrepStmts);
        reportVersion(ds);
        populateData(ds, tblName, seedCount);
        List<Long> oids = getAllIds(ds, tblName);
        System.out.printf("Total records available: %d%n", oids.size());
        Map<Long, String> recs = getRecords(ds, tblName, oids, batchSize);
        System.out.printf("Total records retrieved: %d%n", recs.size());
        // Compare against what is actually in the table (equal to seedCount after a successful
        // seed), so a non-positive seedCount is not misreported as the bug.
        if (recs.size() == oids.size()) {
            System.out.println("All records retrieved: bug was not triggered");
        } else {
            System.out.println("Not all records retrieved: bug was triggered");
        }
    }

    /**
     * Prints the server's {@code version} variable, or a failure notice if it is absent.
     *
     * @param ds data source to query
     * @throws SQLException if the query fails
     */
    private void reportVersion(DataSource ds) throws SQLException {
        // The Statement is a resource too; closing only the Connection and ResultSet leaks it.
        try (Connection conn = ds.getConnection();
             Statement stmt = conn.createStatement();
             ResultSet rs = stmt.executeQuery("SHOW VARIABLES LIKE 'version'")) {
            if (rs.next()) {
                System.out.printf("DB version: %s%n", rs.getString(2));
            } else {
                System.out.printf("Failed to retrieve DB version%n");
            }
        }
    }

    /**
     * Creates the table if needed, empties it, and inserts rows {@code (i, 'Record #i')} for
     * {@code i} in {@code [0, seedCount)}.
     *
     * @param ds data source to use
     * @param tblName table name, interpolated directly into the SQL
     * @param seedCount number of rows to insert; non-positive leaves the table empty
     * @throws SQLException if any statement fails
     */
    private void populateData(DataSource ds, String tblName, int seedCount) throws SQLException {
        try (Connection conn = ds.getConnection();
             Statement stmt = conn.createStatement()) {
            // tblName cannot be bound as a parameter (it is an identifier); it comes from this
            // demo tool's own command line and is trusted.
            stmt.execute(String.format(
                    "CREATE TABLE IF NOT EXISTS %s(id bigint, name varchar(30))", tblName));
            stmt.execute(String.format("TRUNCATE TABLE %s", tblName));
            // Insert in bounded chunks. The loop body never runs for seedCount <= 0, which
            // avoids an invalid empty VALUES list; computing the chunk end from the remaining
            // count keeps the arithmetic free of int overflow.
            for (int start = 0, end; start < seedCount; start = end) {
                end = start + Math.min(INSERT_CHUNK_SIZE, seedCount - start);
                String valuesList = IntStream.range(start, end)
                        .mapToObj(i -> String.format("(%d,'%s')", i, "Record #" + i))
                        .collect(Collectors.joining(","));
                stmt.execute(String.format(
                        "INSERT INTO %s (id, name) VALUES %s", tblName, valuesList));
            }
        }
    }

    /**
     * Reads the {@code id} column of every row in the table.
     *
     * @param ds data source to use
     * @param tblName table to read
     * @return the ids in the order the server returned them
     * @throws SQLException if the query fails
     */
    private List<Long> getAllIds(DataSource ds, String tblName) throws SQLException {
        try (Connection conn = ds.getConnection();
             PreparedStatement ps = conn.prepareStatement("SELECT id FROM " + tblName);
             ResultSet rs = ps.executeQuery()) {
            List<Long> result = new ArrayList<>();
            while (rs.next()) {
                result.add(rs.getLong(1));
            }
            return result;
        }
    }

    /**
     * Fetches the rows for {@code oids} in consecutive batches of at most {@code batchSize}
     * ids, printing each batch's result count, and merges all results.
     *
     * @param ds data source to use
     * @param tblName table to query
     * @param oids ids to look up
     * @param batchSize maximum ids per query; Guava's {@code Lists.partition} rejects
     *     non-positive values with {@link IllegalArgumentException}
     * @return id-to-name map of every row found
     * @throws SQLException if any batch query fails
     */
    private Map<Long, String> getRecords(DataSource ds, String tblName, List<Long> oids,
            int batchSize)
            throws SQLException {
        Map<Long, String> recs = new HashMap<>();
        for (List<Long> ids : Lists.partition(oids, batchSize)) {
            Map<Long, String> batchResults = runQuery(ds, tblName, ids);
            System.out.printf("Batch retrieved record count: %d%n", batchResults.size());
            recs.putAll(batchResults);
        }
        return recs;
    }

    /**
     * Runs {@code SELECT id, name ... WHERE id IN (?, ?, ...)} with one placeholder per id.
     *
     * @param ds data source to use
     * @param tblName table to query
     * @param ids ids to bind, in placeholder order
     * @return id-to-name map of the matching rows; empty when {@code ids} is empty
     * @throws SQLException if the query fails
     */
    private Map<Long, String> runQuery(DataSource ds, String tblName, List<Long> ids)
            throws SQLException {
        Map<Long, String> results = new HashMap<>();
        if (ids.isEmpty()) {
            // "IN ()" is invalid SQL, and nothing can match an empty id list anyway.
            return results;
        }
        String placeholders = ids.stream().map(id -> "?").collect(Collectors.joining(","));
        String sql = String.format("SELECT id, name FROM %s WHERE id IN (%s)", tblName,
                placeholders);
        try (Connection conn = ds.getConnection();
             PreparedStatement ps = conn.prepareStatement(sql)) {
            for (int i = 0; i < ids.size(); i++) {
                ps.setLong(i + 1, ids.get(i));
            }
            try (ResultSet rs = ps.executeQuery()) {
                while (rs.next()) {
                    results.put(rs.getLong(1), rs.getString(2));
                }
            }
        }
        return results;
    }

    /**
     * Builds a MariaDB data source for the given server and database, optionally enabling
     * server-side prepared statements.
     *
     * @param server DB server host name or IP address
     * @param port DB server port
     * @param user login user
     * @param password login password
     * @param dbName database to connect to
     * @param addUseServerPrepStmts when true, sets {@code useServerPrepStmts=true}
     * @return the configured data source (no connection is opened here)
     * @throws SQLException if the driver rejects a setting
     */
    private DataSource getDataSource(String server, int port, String user, String password,
            String dbName, boolean addUseServerPrepStmts) throws SQLException {
        MariaDbDataSource dataSource = new MariaDbDataSource(server, port, dbName);
        dataSource.setUser(user);
        dataSource.setPassword(password);
        if (addUseServerPrepStmts) {
            dataSource.setProperties("useServerPrepStmts=true");
        }
        return dataSource;
    }
}
