Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,11 @@ package org.apache.texera.amber.operator.source.sql.mysql
import java.sql.{Connection, DriverManager, SQLException}

object MySQLConnUtil {

// Builds the JDBC URL.
private[mysql] def buildJdbcUrl(host: String, port: String, database: String): String =
"jdbc:mysql://" + host + ":" + port + "/" + database + "?autoReconnect=true&useSSL=true"

@throws[SQLException]
def connect(
host: String,
Expand All @@ -30,8 +35,7 @@ object MySQLConnUtil {
username: String,
password: String
): Connection = {
val url =
"jdbc:mysql://" + host + ":" + port + "/" + database + "?autoReconnect=true&useSSL=true"
val url = buildJdbcUrl(host, port, database)
val connection = DriverManager.getConnection(url, username, password)
// set to readonly to improve efficiency
connection.setReadOnly(true)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,11 @@ package org.apache.texera.amber.operator.source.sql.postgresql
import java.sql.{Connection, DriverManager, SQLException}

object PostgreSQLConnUtil {

// Builds the JDBC URL.
private[postgresql] def buildJdbcUrl(host: String, port: String, database: String): String =
"jdbc:postgresql://" + host + ":" + port + "/" + database

@throws[SQLException]
def connect(
host: String,
Expand All @@ -30,7 +35,7 @@ object PostgreSQLConnUtil {
username: String,
password: String
): Connection = {
val url = "jdbc:postgresql://" + host + ":" + port + "/" + database
val url = buildJdbcUrl(host, port, database)
val connection = DriverManager.getConnection(url, username, password)
// set to readonly to improve efficiency
connection.setReadOnly(true)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,241 +20,71 @@
package org.apache.texera.amber.operator.source.sql.mysql

import org.scalatest.flatspec.AnyFlatSpec
import org.scalatest.BeforeAndAfterAll

import java.lang.reflect.{InvocationHandler, Method, Proxy}
import java.sql.{Connection, Driver, DriverManager, DriverPropertyInfo, SQLException}
import java.util.Properties
import java.util.logging.Logger
import scala.collection.mutable.ArrayBuffer
import scala.jdk.CollectionConverters._
class MySQLConnUtilSpec extends AnyFlatSpec {

class MySQLConnUtilSpec extends AnyFlatSpec with BeforeAndAfterAll {
// Strategy: same as PostgreSQLConnUtilSpec. Pin the JDBC URL composition
// (the only application-logic in this util) without a real DB.

// ---------------------------------------------------------------------------
// Strategy — same capturing-driver pattern as PostgreSQLConnUtilSpec.
// The MySQL driver may or may not be present transitively, so we
// proactively deregister anything that claims jdbc:mysql: and swap in a
// capturing driver that records each URL and returns a Proxy-backed
// Connection so the production code can call `setReadOnly(true)`.
// URL composition - host/port/database
// ---------------------------------------------------------------------------

private object CapturingMySQLDriver extends Driver {
val seenUrls: ArrayBuffer[String] = ArrayBuffer.empty
val seenProps: ArrayBuffer[Properties] = ArrayBuffer.empty
val readOnlyCalls: ArrayBuffer[Boolean] = ArrayBuffer.empty

override def connect(url: String, info: Properties): Connection = {
if (!acceptsURL(url)) return null
seenUrls += url
seenProps += info
Proxy
.newProxyInstance(
getClass.getClassLoader,
Array(classOf[Connection]),
new InvocationHandler {
override def invoke(p: Any, m: Method, args: Array[AnyRef]): AnyRef =
m.getName match {
case "setReadOnly" =>
readOnlyCalls += args(0).asInstanceOf[java.lang.Boolean].booleanValue()
null
case "equals" => java.lang.Boolean.valueOf(p eq args(0))
case "hashCode" => java.lang.Integer.valueOf(System.identityHashCode(p))
case "toString" =>
"CapturingMySQLDriver.StubConnection@" + System.identityHashCode(p)
case "isWrapperFor" => java.lang.Boolean.FALSE
case "close" => null
case _ => null
}
}
)
.asInstanceOf[Connection]
}
override def acceptsURL(url: String): Boolean =
url != null && url.startsWith("jdbc:mysql:")
override def getPropertyInfo(url: String, info: Properties): Array[DriverPropertyInfo] =
Array.empty
override def getMajorVersion: Int = 1
override def getMinorVersion: Int = 0
override def jdbcCompliant(): Boolean = false
override def getParentLogger: Logger = Logger.getLogger("test-mysql-capturing")
}

private val savedRealDrivers: ArrayBuffer[Driver] = ArrayBuffer.empty

private def safeAcceptsURL(d: Driver, url: String): Boolean =
try d.acceptsURL(url)
catch { case _: Throwable => false }

override protected def beforeAll(): Unit = {
super.beforeAll()
// The probe URL mirrors the exact shape `MySQLConnUtil.connect`
// constructs (`jdbc:mysql://{host}:{port}/{database}?…`), including
// the canonical query parameters. A permissive third-party driver
// that returns `false` on a stripped-down probe but `true` on the
// real URL would otherwise slip past us.
try {
val others = DriverManager.getDrivers.asScala.toList.filter { d =>
d != CapturingMySQLDriver && safeAcceptsURL(
d,
"jdbc:mysql://probe-host:3306/probe-db?autoReconnect=true&useSSL=true"
)
}
others.foreach { d =>
savedRealDrivers += d
DriverManager.deregisterDriver(d)
}
DriverManager.registerDriver(CapturingMySQLDriver)
} catch {
case t: Throwable =>
savedRealDrivers.foreach { d =>
try DriverManager.registerDriver(d)
catch { case _: Throwable => () }
}
throw t
}
}

override protected def afterAll(): Unit = {
try {
try DriverManager.deregisterDriver(CapturingMySQLDriver)
catch { case _: Throwable => () }
savedRealDrivers.foreach { d =>
try DriverManager.registerDriver(d)
catch { case _: Throwable => () }
}
} finally {
super.afterAll()
}
}

private def clearCapture(): Unit = {
CapturingMySQLDriver.seenUrls.clear()
CapturingMySQLDriver.seenProps.clear()
CapturingMySQLDriver.readOnlyCalls.clear()
}

// ---------------------------------------------------------------------------
// URL composition — host/port/database
// ---------------------------------------------------------------------------

"MySQLConnUtil.connect" should
"build a JDBC URL of the form jdbc:mysql://{host}:{port}/{database}?…" in {
clearCapture()
val conn = MySQLConnUtil.connect("host-m", "3306", "db-m", "u", "p")
assert(conn != null)
assert(CapturingMySQLDriver.seenUrls.size == 1)
assert(CapturingMySQLDriver.seenUrls.head.startsWith("jdbc:mysql://host-m:3306/db-m"))
"MySQLConnUtil.buildJdbcUrl" should
"build a URL of the form jdbc:mysql://{host}:{port}/{database}?..." in {
assert(
MySQLConnUtil
.buildJdbcUrl("host-m", "3306", "db-m")
.startsWith("jdbc:mysql://host-m:3306/db-m")
)
}

it should "interpolate distinct host/port/database values into the URL" in {
clearCapture()
MySQLConnUtil.connect("host-1", "3306", "db-1", "u", "p")
assert(CapturingMySQLDriver.seenUrls.head.startsWith("jdbc:mysql://host-1:3306/db-1"))
clearCapture()
MySQLConnUtil.connect("host-2", "33060", "db-2", "u", "p")
assert(CapturingMySQLDriver.seenUrls.head.startsWith("jdbc:mysql://host-2:33060/db-2"))
assert(
MySQLConnUtil
.buildJdbcUrl("host-1", "3306", "db-1")
.startsWith("jdbc:mysql://host-1:3306/db-1")
)
assert(
MySQLConnUtil
.buildJdbcUrl("host-2", "33060", "db-2")
.startsWith("jdbc:mysql://host-2:33060/db-2")
)
}

it should "place host BEFORE port" in {
clearCapture()
MySQLConnUtil.connect("a", "1", "x", "u", "p")
val url = CapturingMySQLDriver.seenUrls.head
val url = MySQLConnUtil.buildJdbcUrl("a", "1", "x")
assert(url.contains("//a:1/"))
assert(!url.contains("//1:a/"))
}

// ---------------------------------------------------------------------------
// Query parameters autoReconnect=true and useSSL=true must be present
// Query parameters - autoReconnect=true and useSSL=true must be present
// ---------------------------------------------------------------------------

it should "include the `autoReconnect=true` query parameter" in {
clearCapture()
MySQLConnUtil.connect("h", "3306", "db", "u", "p")
val url = CapturingMySQLDriver.seenUrls.head
val url = MySQLConnUtil.buildJdbcUrl("h", "3306", "db")
assert(url.contains("autoReconnect=true"), s"URL must include autoReconnect=true, got: $url")
}

it should "include the `useSSL=true` query parameter (TLS contract)" in {
clearCapture()
MySQLConnUtil.connect("h", "3306", "db", "u", "p")
val url = CapturingMySQLDriver.seenUrls.head
val url = MySQLConnUtil.buildJdbcUrl("h", "3306", "db")
assert(url.contains("useSSL=true"), s"URL must include useSSL=true (TLS), got: $url")
}

it should "use the canonical `?…&…` separator pattern" in {
clearCapture()
MySQLConnUtil.connect("h", "3306", "db", "u", "p")
val url = CapturingMySQLDriver.seenUrls.head
it should "use the canonical `?...&...` separator pattern" in {
assert(
url == "jdbc:mysql://h:3306/db?autoReconnect=true&useSSL=true",
s"URL must match canonical pattern, got: $url"
MySQLConnUtil.buildJdbcUrl(
"h",
"3306",
"db"
) == "jdbc:mysql://h:3306/db?autoReconnect=true&useSSL=true"
)
}

it should "use the `mysql` JDBC subprotocol (not e.g. `postgresql`)" in {
clearCapture()
MySQLConnUtil.connect("h", "3306", "db", "u", "p")
val url = CapturingMySQLDriver.seenUrls.head
val url = MySQLConnUtil.buildJdbcUrl("h", "3306", "db")
assert(url.startsWith("jdbc:mysql://"))
assert(!url.contains("jdbc:postgresql:"))
}

// ---------------------------------------------------------------------------
// Credentials propagation
// ---------------------------------------------------------------------------

it should "pass username and password through DriverManager properties" in {
clearCapture()
MySQLConnUtil.connect("h", "3306", "db", "the-user", "the-pass")
val props = CapturingMySQLDriver.seenProps.head
assert(props.getProperty("user") == "the-user")
assert(props.getProperty("password") == "the-pass")
}

// ---------------------------------------------------------------------------
// setReadOnly(true) — pinned via the captured proxy (parity with PG spec)
// ---------------------------------------------------------------------------

it should "flip the returned Connection to read-only (query-efficiency contract)" in {
clearCapture()
MySQLConnUtil.connect("h", "3306", "db", "u", "p")
assert(CapturingMySQLDriver.readOnlyCalls == ArrayBuffer(true))
}

// ---------------------------------------------------------------------------
// SQLException propagation when the driver throws
// ---------------------------------------------------------------------------

it should "propagate SQLException when the driver throws" in {
val throwingDriver = new Driver {
override def acceptsURL(url: String): Boolean =
url != null && url.startsWith("jdbc:mysql:")
// Follow the JDBC contract: return `null` if the URL isn't ours
// and throw only on a matching URL — keeps the helper from
// interfering with `DriverManager.getConnection` calls for any
// other scheme that might happen during the suite.
override def connect(url: String, info: Properties): Connection = {
if (!acceptsURL(url)) return null
throw new SQLException("forced-fail-for-test")
}
override def getPropertyInfo(url: String, info: Properties) =
Array.empty[DriverPropertyInfo]
override def getMajorVersion: Int = 99
override def getMinorVersion: Int = 0
override def jdbcCompliant(): Boolean = false
override def getParentLogger: Logger = Logger.getLogger("test-mysql-throwing")
}
DriverManager.deregisterDriver(CapturingMySQLDriver)
DriverManager.registerDriver(throwingDriver)
try {
val ex = intercept[SQLException] {
MySQLConnUtil.connect("h", "3306", "db", "u", "p")
}
assert(ex.getMessage.contains("forced-fail-for-test"))
} finally {
DriverManager.deregisterDriver(throwingDriver)
DriverManager.registerDriver(CapturingMySQLDriver)
}
}
}
Loading
Loading