From 16e4c20d396e862506b2b97a3f0fa9924c0dc9d4 Mon Sep 17 00:00:00 2001
From: Ben Rowland <ben.rowland@openmarket.com>
Date: Wed, 25 Feb 2015 20:10:44 +0000
Subject: [PATCH] add support for batch rewrite with on duplicate key update

---
 .../org/mariadb/jdbc/MySQLPreparedStatement.java   |   28 ++++++++++++++-
 src/test/java/org/mariadb/jdbc/DriverTest.java     |   37 ++++++++++++++++++++
 2 files changed, 64 insertions(+), 1 deletion(-)

diff --git a/src/main/java/org/mariadb/jdbc/MySQLPreparedStatement.java b/src/main/java/org/mariadb/jdbc/MySQLPreparedStatement.java
index 221ee2f..1c5d5f1 100644
--- a/src/main/java/org/mariadb/jdbc/MySQLPreparedStatement.java
+++ b/src/main/java/org/mariadb/jdbc/MySQLPreparedStatement.java
@@ -230,6 +230,12 @@ public class MySQLPreparedStatement extends MySQLStatement implements PreparedSt
     		isRewriteable = false;
     		return;
     	}
+        int onDuplicateKeyIdx = getOnDuplicateKeyClauseIdx(sql);
+        if (onDuplicateKeyIdx != -1 && sql.toUpperCase().indexOf("?", onDuplicateKeyIdx) != -1) {
+            // parameters cannot be supported for batch rewriting within the ON DUPLICATE KEY UPDATE clause
+            isRewriteable = false;
+            return;
+        }
     	if (firstRewrite == null) {
     		firstRewrite = sql.substring(0, index);
     	}
@@ -248,16 +254,36 @@ public class MySQLPreparedStatement extends MySQLStatement implements PreparedSt
     	if(isRewriteable) {
     		result = new StringBuilder("");
     		result.append(firstRewrite);
+            String onDuplicateKeyClause = null;
     		for (MySQLPreparedStatement mySQLPS : batchPreparedStatements) {
     			String query = mySQLPS.dQuery.toSQL();
-    			result.append(query.substring(getInsertIncipit(query)));
+                int onDuplicateKeyClauseIdx = getOnDuplicateKeyClauseIdx(query);
+                if(onDuplicateKeyClause == null)  {
+                    onDuplicateKeyClause = onDuplicateKeyClauseIdx != -1 ? query.substring(onDuplicateKeyClauseIdx) : "";
+                }
+                int valuesSubstringLimit = onDuplicateKeyClauseIdx != -1 ? onDuplicateKeyClauseIdx : query.length();
+                result.append(query.substring(getInsertIncipit(query), valuesSubstringLimit));
     			result.append(",");
     		}
     		result.deleteCharAt(result.length() - 1);
+            if(!"".equals(onDuplicateKeyClause)) {
+                result.append(onDuplicateKeyClause);
+            }
     	}
     	return (result == null ? null : result.toString());
     }
 
+    /**
+     * Returns the index of the beginning of the ON DUPLICATE KEY UPDATE clause within an INSERT statement if present,
+     * otherwise -1.
+     * @param query
+     * @return index of the beginning of the ON DUPLICATE KEY UPDATE clause if present, otherwise -1.
+     */
+    protected int getOnDuplicateKeyClauseIdx(String query) {
+        int endOfValuesIdx = query.indexOf(")", getInsertIncipit(query));
+        return query.toUpperCase().indexOf("ON DUPLICATE KEY UPDATE", endOfValuesIdx);
+    }
+
     @Override
     public int[] executeBatch() throws SQLException {
         if (batchPreparedStatements == null) {
diff --git a/src/test/java/org/mariadb/jdbc/DriverTest.java b/src/test/java/org/mariadb/jdbc/DriverTest.java
index c9e9405..b9a85b0 100644
--- a/src/test/java/org/mariadb/jdbc/DriverTest.java
+++ b/src/test/java/org/mariadb/jdbc/DriverTest.java
@@ -797,6 +797,43 @@ public class DriverTest extends BaseTest{
     }
 
     @Test
+    public void testBatchLoopWithDupKeyRewriting() throws SQLException {
+        Connection conn = null;
+        PreparedStatement ps = null;
+        ResultSet rs = null;
+        try {
+            conn = openNewConnection(connURI + "&rewriteBatchedStatements=true");
+            conn.createStatement().execute("drop table if exists rewritetest3");
+            conn.createStatement().execute("create table rewritetest3 (id int not null primary key, a varchar(10)) engine=innodb");
+            ps = conn.prepareStatement("insert into rewritetest3 values (?,?) on duplicate key update a=values(a)");
+
+            for(int i=0; i<3; i++) {
+                ps.setInt(1, 1);
+                StringBuilder sb = new StringBuilder();
+                // insert values of increasing length to test the values clause is correctly indexed each time
+                for(int j=0; j<=i; j++) {
+                    sb.append("a");
+                }
+                ps.setString(2, sb.toString());
+                ps.addBatch();
+            }
+            ps.executeBatch();
+
+            rs = conn.createStatement().executeQuery("select * from rewritetest3");
+            assertTrue(rs.next());
+            assertEquals(1, rs.getInt("id"));
+            assertEquals("aaa", rs.getString("a"));
+        } finally {
+            if (rs != null)
+                rs.close();
+            if (ps != null)
+                ps.close();
+            if (conn != null)
+                conn.close();
+        }
+    }
+
+    @Test
     public void testPreparedStatementsWithComments() throws SQLException {
         connection.createStatement().execute("drop table if exists commentPreparedStatements");
         connection.createStatement().execute(
-- 
1.7.9.5

