QueryAdapterFactory.java
package com.github.jonasrutishauser.transactional.event.core.store;
import static java.util.Arrays.asList;

import java.sql.Connection;
import java.sql.DatabaseMetaData;
import java.sql.SQLException;
import java.util.HashSet;
import java.util.Set;

import jakarta.enterprise.context.ApplicationScoped;
import jakarta.enterprise.inject.Instance;
import jakarta.enterprise.inject.TransientReference;
import jakarta.inject.Inject;
import javax.sql.DataSource;

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

import com.github.jonasrutishauser.transactional.event.api.Events;
import com.github.jonasrutishauser.transactional.event.api.store.QueryAdapter;
/**
 * Supplies the {@link QueryAdapter} used to adapt the event store's SQL statements to the target database.
 * <p>
 * When a {@link QueryAdapter} is resolvable through CDI, that instance is used. Otherwise an adapter is
 * selected from the JDBC metadata (product name and SQL keywords) of the {@link Events}-qualified
 * {@link DataSource}.
 */
@ApplicationScoped
class QueryAdapterFactory {

    private static final Logger LOGGER = LogManager.getLogger();

    private final QueryAdapter queryAdapter;

    /**
     * Returns the adapter chosen when this factory was constructed.
     *
     * @return the query adapter; {@code null} only for an instance created via the no-arg constructor
     */
    public QueryAdapter getQueryAdapter() {
        return queryAdapter;
    }

    /**
     * Creates a factory without an adapter.
     */
    // NOTE(review): presumably required by CDI to create client proxies of this normal-scoped bean — confirm.
    QueryAdapterFactory() {
        queryAdapter = null;
    }

    /**
     * Creates a factory whose adapter is detected from the metadata of the given data source.
     *
     * @param dataSource data source used to open a short-lived connection for metadata inspection
     * @throws SQLException if the connection or its metadata cannot be obtained
     */
    QueryAdapterFactory(DataSource dataSource) throws SQLException {
        queryAdapter = detectQueryAdapter(dataSource);
    }

    /**
     * CDI constructor: uses a CDI-provided {@link QueryAdapter} if one is resolvable, otherwise detects one
     * from the metadata of the events data source.
     *
     * @param cdiInstance lookup for a user-provided query adapter
     * @param dataSource events data source; only used when no adapter is resolvable through CDI
     * @throws IllegalStateException if detection from the data source fails with an {@link SQLException}
     */
    @Inject
    QueryAdapterFactory(Instance<QueryAdapter> cdiInstance, @TransientReference @Events DataSource dataSource) {
        if (cdiInstance.isUnsatisfied()) {
            try {
                queryAdapter = detectQueryAdapter(dataSource);
            } catch (SQLException e) {
                throw new IllegalStateException("Failed to detect the query adapter from the events data source",
                        e);
            }
        } else {
            queryAdapter = cdiInstance.get();
        }
    }

    /**
     * Opens a connection, selects an adapter from its metadata and closes the connection again.
     */
    private static QueryAdapter detectQueryAdapter(DataSource dataSource) throws SQLException {
        try (Connection connection = dataSource.getConnection()) {
            return getQueryAdapter(connection);
        }
    }

    /**
     * Selects an adapter from the database product name, falling back to a SQL-keyword check for
     * databases not matched by name.
     */
    private static QueryAdapter getQueryAdapter(Connection connection) throws SQLException {
        DatabaseMetaData metaData = connection.getMetaData();
        String productName = metaData.getDatabaseProductName();
        QueryAdapter queryAdapter;
        if (productName.contains("Oracle")) {
            queryAdapter = new OracleQueryAdapter();
        } else if (productName.contains("MariaDB")) {
            // checked before "MySQL" so MariaDB gets its dedicated adapter
            queryAdapter = new MariaDBQueryAdapter();
        } else if (productName.contains("PostgreSQL") || productName.contains("MySQL")) {
            queryAdapter = new LimitQueryAdapter();
        } else if (supportsSkipLocked(metaData)) {
            queryAdapter = new SkipLockedQueryAdapter();
        } else {
            queryAdapter = new SimpleQueryAdapter();
        }
        LOGGER.debug(() -> "DB '" + productName + "' uses " + queryAdapter.getClass().getSimpleName());
        return queryAdapter;
    }

    /**
     * Returns whether both {@code skip} and {@code locked} appear in the driver's comma-separated keyword list.
     * Tokens are trimmed so lists separated by {@code ", "} are handled; a {@code null} list counts as empty.
     */
    private static boolean supportsSkipLocked(DatabaseMetaData metaData) throws SQLException {
        String sqlKeywords = metaData.getSQLKeywords();
        if (sqlKeywords == null) {
            return false;
        }
        Set<String> keywords = new HashSet<>();
        for (String keyword : sqlKeywords.split(",")) {
            keywords.add(keyword.trim());
        }
        // NOTE(review): matching is case-sensitive (lowercase only); drivers reporting uppercase keywords
        // fall back to SimpleQueryAdapter. Kept deliberately — confirm before broadening.
        return keywords.contains("skip") && keywords.contains("locked");
    }

    /**
     * Fallback adapter that leaves statements unchanged.
     */
    private static class SimpleQueryAdapter implements QueryAdapter {
        /** Matches a {@code {LIMIT n}} placeholder, capturing {@code n} as group 1. */
        protected static final String LIMIT_EXPRESSION = "\\{LIMIT ([^}]+)\\}";

        /**
         * Returns {@code sql} unchanged; any {@code {LIMIT n}} placeholder is kept as-is.
         */
        // NOTE(review): presumably left for the JDBC driver's {limit} escape processing — confirm.
        @Override
        public String fixLimits(String sql) {
            return sql;
        }

        /**
         * Returns {@code sql} unchanged.
         */
        @Override
        public String addSkipLocked(String sql) {
            return sql;
        }
    }

    /**
     * Adapter for databases supporting {@code FOR UPDATE SKIP LOCKED}.
     */
    private static class SkipLockedQueryAdapter extends SimpleQueryAdapter {
        /**
         * Removes every {@code {LIMIT n}} placeholder and turns each {@code FOR UPDATE} into
         * {@code FOR UPDATE SKIP LOCKED}.
         */
        @Override
        public String addSkipLocked(String sql) {
            return sql.replaceAll(LIMIT_EXPRESSION, "").replace("FOR UPDATE", "FOR UPDATE SKIP LOCKED");
        }
    }

    /**
     * Oracle adapter: rewrites limits as {@code rownum} conditions and adds a {@code FIRST_ROWS} hint when
     * skip-locked is requested with a row count.
     */
    private static class OracleQueryAdapter extends SkipLockedQueryAdapter {
        /**
         * Inserts a {@code FIRST_ROWS(maxRows)} hint after every {@code "SELECT "} (including any subqueries),
         * then applies the skip-locked rewrite.
         */
        @Override
        public String addSkipLocked(String sql, int maxRows) {
            return super.addSkipLocked(sql.replace("SELECT ", "SELECT /*+ FIRST_ROWS(" + maxRows + ") */ "));
        }

        /**
         * Replaces {@code {LIMIT n}} with {@code AND rownum <= n}.
         */
        @Override
        public String fixLimits(String sql) {
            return sql.replaceAll(LIMIT_EXPRESSION, "AND rownum <= $1");
        }
    }

    /**
     * MariaDB adapter: rewrites limits as {@code LIMIT n}; skip-locked is not applied (statements unchanged).
     */
    private static class MariaDBQueryAdapter extends SimpleQueryAdapter {
        /**
         * Replaces {@code {LIMIT n}} with {@code LIMIT n}.
         */
        @Override
        public String fixLimits(String sql) {
            return sql.replaceAll(LIMIT_EXPRESSION, "LIMIT $1");
        }
    }

    /**
     * Adapter for PostgreSQL and MySQL: rewrites limits as {@code LIMIT n} and supports skip-locked.
     */
    private static class LimitQueryAdapter extends SkipLockedQueryAdapter {
        /**
         * Replaces {@code {LIMIT n}} with {@code LIMIT n}.
         */
        @Override
        public String fixLimits(String sql) {
            return sql.replaceAll(LIMIT_EXPRESSION, "LIMIT $1");
        }
    }
}