ScriptRunner.java

  1. /*
  2.  *    Copyright 2009-2023 the original author or authors.
  3.  *
  4.  *    Licensed under the Apache License, Version 2.0 (the "License");
  5.  *    you may not use this file except in compliance with the License.
  6.  *    You may obtain a copy of the License at
  7.  *
  8.  *       https://www.apache.org/licenses/LICENSE-2.0
  9.  *
  10.  *    Unless required by applicable law or agreed to in writing, software
  11.  *    distributed under the License is distributed on an "AS IS" BASIS,
  12.  *    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13.  *    See the License for the specific language governing permissions and
  14.  *    limitations under the License.
  15.  */
  16. package org.apache.ibatis.jdbc;

  17. import java.io.BufferedReader;
  18. import java.io.PrintWriter;
  19. import java.io.Reader;
  20. import java.sql.Connection;
  21. import java.sql.ResultSet;
  22. import java.sql.ResultSetMetaData;
  23. import java.sql.SQLException;
  24. import java.sql.SQLWarning;
  25. import java.sql.Statement;
  26. import java.util.regex.Matcher;
  27. import java.util.regex.Pattern;

  28. /**
  29.  * This is an internal testing utility.<br>
  30.  * You are welcome to use this class for your own purposes,<br>
  31.  * but if there is some feature/enhancement you need for your own usage,<br>
  32.  * please make and modify your own copy instead of sending us an enhancement request.<br>
  33.  *
  34.  * @author Clinton Begin
  35.  */
  36. public class ScriptRunner {

  37.   private static final String LINE_SEPARATOR = System.lineSeparator();

  38.   private static final String DEFAULT_DELIMITER = ";";

  39.   private static final Pattern DELIMITER_PATTERN = Pattern
  40.       .compile("^\\s*((--)|(//))?\\s*(//)?\\s*@DELIMITER\\s+([^\\s]+)", Pattern.CASE_INSENSITIVE);

  41.   private final Connection connection;

  42.   private boolean stopOnError;
  43.   private boolean throwWarning;
  44.   private boolean autoCommit;
  45.   private boolean sendFullScript;
  46.   private boolean removeCRs;
  47.   private boolean escapeProcessing = true;

  48.   private PrintWriter logWriter = new PrintWriter(System.out);
  49.   private PrintWriter errorLogWriter = new PrintWriter(System.err);

  50.   private String delimiter = DEFAULT_DELIMITER;
  51.   private boolean fullLineDelimiter;

  52.   public ScriptRunner(Connection connection) {
  53.     this.connection = connection;
  54.   }

  55.   public void setStopOnError(boolean stopOnError) {
  56.     this.stopOnError = stopOnError;
  57.   }

  58.   public void setThrowWarning(boolean throwWarning) {
  59.     this.throwWarning = throwWarning;
  60.   }

  61.   public void setAutoCommit(boolean autoCommit) {
  62.     this.autoCommit = autoCommit;
  63.   }

  64.   public void setSendFullScript(boolean sendFullScript) {
  65.     this.sendFullScript = sendFullScript;
  66.   }

  67.   public void setRemoveCRs(boolean removeCRs) {
  68.     this.removeCRs = removeCRs;
  69.   }

  70.   /**
  71.    * Sets the escape processing.
  72.    *
  73.    * @param escapeProcessing
  74.    *          the new escape processing
  75.    *
  76.    * @since 3.1.1
  77.    */
  78.   public void setEscapeProcessing(boolean escapeProcessing) {
  79.     this.escapeProcessing = escapeProcessing;
  80.   }

  81.   public void setLogWriter(PrintWriter logWriter) {
  82.     this.logWriter = logWriter;
  83.   }

  84.   public void setErrorLogWriter(PrintWriter errorLogWriter) {
  85.     this.errorLogWriter = errorLogWriter;
  86.   }

  87.   public void setDelimiter(String delimiter) {
  88.     this.delimiter = delimiter;
  89.   }

  90.   public void setFullLineDelimiter(boolean fullLineDelimiter) {
  91.     this.fullLineDelimiter = fullLineDelimiter;
  92.   }

  93.   public void runScript(Reader reader) {
  94.     setAutoCommit();

  95.     try {
  96.       if (sendFullScript) {
  97.         executeFullScript(reader);
  98.       } else {
  99.         executeLineByLine(reader);
  100.       }
  101.     } finally {
  102.       rollbackConnection();
  103.     }
  104.   }

  105.   private void executeFullScript(Reader reader) {
  106.     StringBuilder script = new StringBuilder();
  107.     try {
  108.       BufferedReader lineReader = new BufferedReader(reader);
  109.       String line;
  110.       while ((line = lineReader.readLine()) != null) {
  111.         script.append(line);
  112.         script.append(LINE_SEPARATOR);
  113.       }
  114.       String command = script.toString();
  115.       println(command);
  116.       executeStatement(command);
  117.       commitConnection();
  118.     } catch (Exception e) {
  119.       String message = "Error executing: " + script + ".  Cause: " + e;
  120.       printlnError(message);
  121.       throw new RuntimeSqlException(message, e);
  122.     }
  123.   }

  124.   private void executeLineByLine(Reader reader) {
  125.     StringBuilder command = new StringBuilder();
  126.     try {
  127.       BufferedReader lineReader = new BufferedReader(reader);
  128.       String line;
  129.       while ((line = lineReader.readLine()) != null) {
  130.         handleLine(command, line);
  131.       }
  132.       commitConnection();
  133.       checkForMissingLineTerminator(command);
  134.     } catch (Exception e) {
  135.       String message = "Error executing: " + command + ".  Cause: " + e;
  136.       printlnError(message);
  137.       throw new RuntimeSqlException(message, e);
  138.     }
  139.   }

  140.   /**
  141.    * @deprecated Since 3.5.4, this method is deprecated. Please close the {@link Connection} outside of this class.
  142.    */
  143.   @Deprecated
  144.   public void closeConnection() {
  145.     try {
  146.       connection.close();
  147.     } catch (Exception e) {
  148.       // ignore
  149.     }
  150.   }

  151.   private void setAutoCommit() {
  152.     try {
  153.       if (autoCommit != connection.getAutoCommit()) {
  154.         connection.setAutoCommit(autoCommit);
  155.       }
  156.     } catch (Throwable t) {
  157.       throw new RuntimeSqlException("Could not set AutoCommit to " + autoCommit + ". Cause: " + t, t);
  158.     }
  159.   }

  160.   private void commitConnection() {
  161.     try {
  162.       if (!connection.getAutoCommit()) {
  163.         connection.commit();
  164.       }
  165.     } catch (Throwable t) {
  166.       throw new RuntimeSqlException("Could not commit transaction. Cause: " + t, t);
  167.     }
  168.   }

  169.   private void rollbackConnection() {
  170.     try {
  171.       if (!connection.getAutoCommit()) {
  172.         connection.rollback();
  173.       }
  174.     } catch (Throwable t) {
  175.       // ignore
  176.     }
  177.   }

  178.   private void checkForMissingLineTerminator(StringBuilder command) {
  179.     if (command != null && command.toString().trim().length() > 0) {
  180.       throw new RuntimeSqlException("Line missing end-of-line terminator (" + delimiter + ") => " + command);
  181.     }
  182.   }

  183.   private void handleLine(StringBuilder command, String line) throws SQLException {
  184.     String trimmedLine = line.trim();
  185.     if (lineIsComment(trimmedLine)) {
  186.       Matcher matcher = DELIMITER_PATTERN.matcher(trimmedLine);
  187.       if (matcher.find()) {
  188.         delimiter = matcher.group(5);
  189.       }
  190.       println(trimmedLine);
  191.     } else if (commandReadyToExecute(trimmedLine)) {
  192.       command.append(line, 0, line.lastIndexOf(delimiter));
  193.       command.append(LINE_SEPARATOR);
  194.       println(command);
  195.       executeStatement(command.toString());
  196.       command.setLength(0);
  197.     } else if (trimmedLine.length() > 0) {
  198.       command.append(line);
  199.       command.append(LINE_SEPARATOR);
  200.     }
  201.   }

  202.   private boolean lineIsComment(String trimmedLine) {
  203.     return trimmedLine.startsWith("//") || trimmedLine.startsWith("--");
  204.   }

  205.   private boolean commandReadyToExecute(String trimmedLine) {
  206.     // issue #561 remove anything after the delimiter
  207.     return !fullLineDelimiter && trimmedLine.contains(delimiter) || fullLineDelimiter && trimmedLine.equals(delimiter);
  208.   }

  209.   private void executeStatement(String command) throws SQLException {
  210.     try (Statement statement = connection.createStatement()) {
  211.       statement.setEscapeProcessing(escapeProcessing);
  212.       String sql = command;
  213.       if (removeCRs) {
  214.         sql = sql.replace("\r\n", "\n");
  215.       }
  216.       try {
  217.         boolean hasResults = statement.execute(sql);
  218.         // DO NOT try to 'improve' the condition even if IDE tells you to!
  219.         // It's important that getUpdateCount() is called here.
  220.         while (!(!hasResults && statement.getUpdateCount() == -1)) {
  221.           checkWarnings(statement);
  222.           printResults(statement, hasResults);
  223.           hasResults = statement.getMoreResults();
  224.         }
  225.       } catch (SQLWarning e) {
  226.         throw e;
  227.       } catch (SQLException e) {
  228.         if (stopOnError) {
  229.           throw e;
  230.         }
  231.         String message = "Error executing: " + command + ".  Cause: " + e;
  232.         printlnError(message);
  233.       }
  234.     }
  235.   }

  236.   private void checkWarnings(Statement statement) throws SQLException {
  237.     if (!throwWarning) {
  238.       return;
  239.     }
  240.     // In Oracle, CREATE PROCEDURE, FUNCTION, etc. returns warning
  241.     // instead of throwing exception if there is compilation error.
  242.     SQLWarning warning = statement.getWarnings();
  243.     if (warning != null) {
  244.       throw warning;
  245.     }
  246.   }

  247.   private void printResults(Statement statement, boolean hasResults) {
  248.     if (!hasResults) {
  249.       return;
  250.     }
  251.     try (ResultSet rs = statement.getResultSet()) {
  252.       ResultSetMetaData md = rs.getMetaData();
  253.       int cols = md.getColumnCount();
  254.       for (int i = 0; i < cols; i++) {
  255.         String name = md.getColumnLabel(i + 1);
  256.         print(name + "\t");
  257.       }
  258.       println("");
  259.       while (rs.next()) {
  260.         for (int i = 0; i < cols; i++) {
  261.           String value = rs.getString(i + 1);
  262.           print(value + "\t");
  263.         }
  264.         println("");
  265.       }
  266.     } catch (SQLException e) {
  267.       printlnError("Error printing results: " + e.getMessage());
  268.     }
  269.   }

  270.   private void print(Object o) {
  271.     if (logWriter != null) {
  272.       logWriter.print(o);
  273.       logWriter.flush();
  274.     }
  275.   }

  276.   private void println(Object o) {
  277.     if (logWriter != null) {
  278.       logWriter.println(o);
  279.       logWriter.flush();
  280.     }
  281.   }

  282.   private void printlnError(Object o) {
  283.     if (errorLogWriter != null) {
  284.       errorLogWriter.println(o);
  285.       errorLogWriter.flush();
  286.     }
  287.   }

  288. }