View Javadoc
1   /*
2    *    Copyright 2009-2023 the original author or authors.
3    *
4    *    Licensed under the Apache License, Version 2.0 (the "License");
5    *    you may not use this file except in compliance with the License.
6    *    You may obtain a copy of the License at
7    *
8    *       https://www.apache.org/licenses/LICENSE-2.0
9    *
10   *    Unless required by applicable law or agreed to in writing, software
11   *    distributed under the License is distributed on an "AS IS" BASIS,
12   *    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13   *    See the License for the specific language governing permissions and
14   *    limitations under the License.
15   */
16  package org.apache.ibatis.jdbc;
17  
18  import java.io.BufferedReader;
19  import java.io.PrintWriter;
20  import java.io.Reader;
21  import java.sql.Connection;
22  import java.sql.ResultSet;
23  import java.sql.ResultSetMetaData;
24  import java.sql.SQLException;
25  import java.sql.SQLWarning;
26  import java.sql.Statement;
27  import java.util.regex.Matcher;
28  import java.util.regex.Pattern;
29  
30  /**
31   * This is an internal testing utility.<br>
32   * You are welcome to use this class for your own purposes,<br>
33   * but if there is some feature/enhancement you need for your own usage,<br>
34   * please make and modify your own copy instead of sending us an enhancement request.<br>
35   *
36   * @author Clinton Begin
37   */
38  public class ScriptRunner {
39  
40    private static final String LINE_SEPARATOR = System.lineSeparator();
41  
42    private static final String DEFAULT_DELIMITER = ";";
43  
44    private static final Pattern DELIMITER_PATTERN = Pattern
45        .compile("^\\s*((--)|(//))?\\s*(//)?\\s*@DELIMITER\\s+([^\\s]+)", Pattern.CASE_INSENSITIVE);
46  
47    private final Connection connection;
48  
49    private boolean stopOnError;
50    private boolean throwWarning;
51    private boolean autoCommit;
52    private boolean sendFullScript;
53    private boolean removeCRs;
54    private boolean escapeProcessing = true;
55  
56    private PrintWriter logWriter = new PrintWriter(System.out);
57    private PrintWriter errorLogWriter = new PrintWriter(System.err);
58  
59    private String delimiter = DEFAULT_DELIMITER;
60    private boolean fullLineDelimiter;
61  
62    public ScriptRunner(Connection connection) {
63      this.connection = connection;
64    }
65  
66    public void setStopOnError(boolean stopOnError) {
67      this.stopOnError = stopOnError;
68    }
69  
70    public void setThrowWarning(boolean throwWarning) {
71      this.throwWarning = throwWarning;
72    }
73  
74    public void setAutoCommit(boolean autoCommit) {
75      this.autoCommit = autoCommit;
76    }
77  
78    public void setSendFullScript(boolean sendFullScript) {
79      this.sendFullScript = sendFullScript;
80    }
81  
82    public void setRemoveCRs(boolean removeCRs) {
83      this.removeCRs = removeCRs;
84    }
85  
86    /**
87     * Sets the escape processing.
88     *
89     * @param escapeProcessing
90     *          the new escape processing
91     *
92     * @since 3.1.1
93     */
94    public void setEscapeProcessing(boolean escapeProcessing) {
95      this.escapeProcessing = escapeProcessing;
96    }
97  
98    public void setLogWriter(PrintWriter logWriter) {
99      this.logWriter = logWriter;
100   }
101 
102   public void setErrorLogWriter(PrintWriter errorLogWriter) {
103     this.errorLogWriter = errorLogWriter;
104   }
105 
106   public void setDelimiter(String delimiter) {
107     this.delimiter = delimiter;
108   }
109 
110   public void setFullLineDelimiter(boolean fullLineDelimiter) {
111     this.fullLineDelimiter = fullLineDelimiter;
112   }
113 
114   public void runScript(Reader reader) {
115     setAutoCommit();
116 
117     try {
118       if (sendFullScript) {
119         executeFullScript(reader);
120       } else {
121         executeLineByLine(reader);
122       }
123     } finally {
124       rollbackConnection();
125     }
126   }
127 
128   private void executeFullScript(Reader reader) {
129     StringBuilder script = new StringBuilder();
130     try {
131       BufferedReader lineReader = new BufferedReader(reader);
132       String line;
133       while ((line = lineReader.readLine()) != null) {
134         script.append(line);
135         script.append(LINE_SEPARATOR);
136       }
137       String command = script.toString();
138       println(command);
139       executeStatement(command);
140       commitConnection();
141     } catch (Exception e) {
142       String message = "Error executing: " + script + ".  Cause: " + e;
143       printlnError(message);
144       throw new RuntimeSqlException(message, e);
145     }
146   }
147 
148   private void executeLineByLine(Reader reader) {
149     StringBuilder command = new StringBuilder();
150     try {
151       BufferedReader lineReader = new BufferedReader(reader);
152       String line;
153       while ((line = lineReader.readLine()) != null) {
154         handleLine(command, line);
155       }
156       commitConnection();
157       checkForMissingLineTerminator(command);
158     } catch (Exception e) {
159       String message = "Error executing: " + command + ".  Cause: " + e;
160       printlnError(message);
161       throw new RuntimeSqlException(message, e);
162     }
163   }
164 
165   /**
166    * @deprecated Since 3.5.4, this method is deprecated. Please close the {@link Connection} outside of this class.
167    */
168   @Deprecated
169   public void closeConnection() {
170     try {
171       connection.close();
172     } catch (Exception e) {
173       // ignore
174     }
175   }
176 
177   private void setAutoCommit() {
178     try {
179       if (autoCommit != connection.getAutoCommit()) {
180         connection.setAutoCommit(autoCommit);
181       }
182     } catch (Throwable t) {
183       throw new RuntimeSqlException("Could not set AutoCommit to " + autoCommit + ". Cause: " + t, t);
184     }
185   }
186 
187   private void commitConnection() {
188     try {
189       if (!connection.getAutoCommit()) {
190         connection.commit();
191       }
192     } catch (Throwable t) {
193       throw new RuntimeSqlException("Could not commit transaction. Cause: " + t, t);
194     }
195   }
196 
197   private void rollbackConnection() {
198     try {
199       if (!connection.getAutoCommit()) {
200         connection.rollback();
201       }
202     } catch (Throwable t) {
203       // ignore
204     }
205   }
206 
207   private void checkForMissingLineTerminator(StringBuilder command) {
208     if (command != null && command.toString().trim().length() > 0) {
209       throw new RuntimeSqlException("Line missing end-of-line terminator (" + delimiter + ") => " + command);
210     }
211   }
212 
213   private void handleLine(StringBuilder command, String line) throws SQLException {
214     String trimmedLine = line.trim();
215     if (lineIsComment(trimmedLine)) {
216       Matcher matcher = DELIMITER_PATTERN.matcher(trimmedLine);
217       if (matcher.find()) {
218         delimiter = matcher.group(5);
219       }
220       println(trimmedLine);
221     } else if (commandReadyToExecute(trimmedLine)) {
222       command.append(line, 0, line.lastIndexOf(delimiter));
223       command.append(LINE_SEPARATOR);
224       println(command);
225       executeStatement(command.toString());
226       command.setLength(0);
227     } else if (trimmedLine.length() > 0) {
228       command.append(line);
229       command.append(LINE_SEPARATOR);
230     }
231   }
232 
233   private boolean lineIsComment(String trimmedLine) {
234     return trimmedLine.startsWith("//") || trimmedLine.startsWith("--");
235   }
236 
237   private boolean commandReadyToExecute(String trimmedLine) {
238     // issue #561 remove anything after the delimiter
239     return !fullLineDelimiter && trimmedLine.contains(delimiter) || fullLineDelimiter && trimmedLine.equals(delimiter);
240   }
241 
242   private void executeStatement(String command) throws SQLException {
243     try (Statement statement = connection.createStatement()) {
244       statement.setEscapeProcessing(escapeProcessing);
245       String sql = command;
246       if (removeCRs) {
247         sql = sql.replace("\r\n", "\n");
248       }
249       try {
250         boolean hasResults = statement.execute(sql);
251         // DO NOT try to 'improve' the condition even if IDE tells you to!
252         // It's important that getUpdateCount() is called here.
253         while (!(!hasResults && statement.getUpdateCount() == -1)) {
254           checkWarnings(statement);
255           printResults(statement, hasResults);
256           hasResults = statement.getMoreResults();
257         }
258       } catch (SQLWarning e) {
259         throw e;
260       } catch (SQLException e) {
261         if (stopOnError) {
262           throw e;
263         }
264         String message = "Error executing: " + command + ".  Cause: " + e;
265         printlnError(message);
266       }
267     }
268   }
269 
270   private void checkWarnings(Statement statement) throws SQLException {
271     if (!throwWarning) {
272       return;
273     }
274     // In Oracle, CREATE PROCEDURE, FUNCTION, etc. returns warning
275     // instead of throwing exception if there is compilation error.
276     SQLWarning warning = statement.getWarnings();
277     if (warning != null) {
278       throw warning;
279     }
280   }
281 
282   private void printResults(Statement statement, boolean hasResults) {
283     if (!hasResults) {
284       return;
285     }
286     try (ResultSet rs = statement.getResultSet()) {
287       ResultSetMetaData md = rs.getMetaData();
288       int cols = md.getColumnCount();
289       for (int i = 0; i < cols; i++) {
290         String name = md.getColumnLabel(i + 1);
291         print(name + "\t");
292       }
293       println("");
294       while (rs.next()) {
295         for (int i = 0; i < cols; i++) {
296           String value = rs.getString(i + 1);
297           print(value + "\t");
298         }
299         println("");
300       }
301     } catch (SQLException e) {
302       printlnError("Error printing results: " + e.getMessage());
303     }
304   }
305 
306   private void print(Object o) {
307     if (logWriter != null) {
308       logWriter.print(o);
309       logWriter.flush();
310     }
311   }
312 
313   private void println(Object o) {
314     if (logWriter != null) {
315       logWriter.println(o);
316       logWriter.flush();
317     }
318   }
319 
320   private void printlnError(Object o) {
321     if (errorLogWriter != null) {
322       errorLogWriter.println(o);
323       errorLogWriter.flush();
324     }
325   }
326 
327 }