From f62ffd0f44128757b137137c7d06daa97ccc9bb3 Mon Sep 17 00:00:00 2001 From: vincenthql <229323147@qq.com> Date: Sun, 22 Mar 2026 08:28:20 +0800 Subject: [PATCH] feat: implement adb reverse --- .../kotlin/com/flyfishxu/kadb/Kadb.kt | 23 +++ .../com/flyfishxu/kadb/core/AdbConnection.kt | 144 +++++++++++++++++- .../com/flyfishxu/kadb/reverse/ReverseCore.kt | 117 ++++++++++++++ 3 files changed, 283 insertions(+), 1 deletion(-) create mode 100644 kadb/src/commonMain/kotlin/com/flyfishxu/kadb/reverse/ReverseCore.kt diff --git a/kadb/src/commonMain/kotlin/com/flyfishxu/kadb/Kadb.kt b/kadb/src/commonMain/kotlin/com/flyfishxu/kadb/Kadb.kt index 3449327..48b0022 100644 --- a/kadb/src/commonMain/kotlin/com/flyfishxu/kadb/Kadb.kt +++ b/kadb/src/commonMain/kotlin/com/flyfishxu/kadb/Kadb.kt @@ -5,6 +5,8 @@ import com.flyfishxu.kadb.cert.platform.defaultDeviceName import com.flyfishxu.kadb.core.AdbConnection import com.flyfishxu.kadb.forwarding.TcpForwarder import com.flyfishxu.kadb.pair.PairingConnectionCtx +import com.flyfishxu.kadb.reverse.AdbReverse +import com.flyfishxu.kadb.reverse.AdbReverseRule import com.flyfishxu.kadb.shell.AdbShellResponse import com.flyfishxu.kadb.shell.AdbShellStream import com.flyfishxu.kadb.stream.AdbStream @@ -27,6 +29,7 @@ class Kadb( ) : AutoCloseable { private var connection: Pair? = null + private val adbReverse by lazy { AdbReverse(this) } fun connectionCheck(): Boolean = connection?.second?.isOpen == true @@ -196,6 +199,26 @@ class Kadb( return forwarder } + fun reverseForward(device: String, host: String, noRebind: Boolean = false) { + adbReverse.create(device, host, noRebind) + } + + fun reverseForward(devicePort: Int, hostPort: Int, noRebind: Boolean = false) { + reverseForward("tcp:$devicePort", "tcp:$hostPort", noRebind) + } + + fun reverseKillForward(device: String) { + adbReverse.remove(device) + } + + fun reverseKillAllForwards() { + adbReverse.removeAll() + } + + fun reverseListForwards(): List { + return adbReverse.list() + } + private fun restartAdb(destination: String): String { this.open(destination).use { stream -> return stream.source.readUntil('\n'.code.toByte()).readString(Charsets.UTF_8) diff --git a/kadb/src/commonMain/kotlin/com/flyfishxu/kadb/core/AdbConnection.kt b/kadb/src/commonMain/kotlin/com/flyfishxu/kadb/core/AdbConnection.kt index 0a6794e..02e7bed 100644 --- a/kadb/src/commonMain/kotlin/com/flyfishxu/kadb/core/AdbConnection.kt +++ b/kadb/src/commonMain/kotlin/com/flyfishxu/kadb/core/AdbConnection.kt @@ -2,6 +2,7 @@ package com.flyfishxu.kadb.core import com.flyfishxu.kadb.cert.AdbKeyPair import com.flyfishxu.kadb.cert.platform.defaultDeviceName +import com.flyfishxu.kadb.debug.log import com.flyfishxu.kadb.pair.SslUtils import com.flyfishxu.kadb.queue.AdbMessageQueue import com.flyfishxu.kadb.stream.AdbStream @@ -19,11 +20,19 @@ import java.nio.ByteBuffer import java.nio.ByteOrder import java.security.interfaces.RSAPublicKey import java.util.* +import java.util.concurrent.ConcurrentLinkedQueue import java.util.concurrent.TimeUnit import javax.net.ssl.SSLProtocolException import kotlin.Throws import kotlin.io.encoding.Base64 import kotlin.io.encoding.ExperimentalEncodingApi +import kotlin.concurrent.thread +import okio.BufferedSink +import okio.Source +import okio.buffer +import okio.sink +import okio.source +import java.net.Socket internal class AdbConnection internal constructor( adbReader: AdbReader, @@ -36,6 +45,12 @@ internal class AdbConnection internal constructor( private val random = Random() private val messageQueue = AdbMessageQueue(adbReader) + private var reverseThread: Thread? = null + private val reverseSessionThreads = ConcurrentLinkedQueue() + + init { + startReverseBridgeLoop() + } @Throws(IOException::class) fun open(destination: String): AdbStream { @@ -58,7 +73,11 @@ internal class AdbConnection internal constructor( } private fun newId(): Int { - return random.nextInt() + var id: Int + do { + id = random.nextInt() + } while (id == 0) + return id } @TestOnly @@ -68,6 +87,13 @@ internal class AdbConnection internal constructor( override fun close() { try { + reverseThread?.interrupt() + reverseThread = null + while (true) { + val t = reverseSessionThreads.poll() ?: break + t.interrupt() + } + runCatching { messageQueue.stopListening(REVERSE_LISTENER_ID) } messageQueue.close() adbWriter.close() closeable?.close() @@ -75,7 +101,94 @@ internal class AdbConnection internal constructor( } } + private fun startReverseBridgeLoop() { + messageQueue.startListening(REVERSE_LISTENER_ID) + reverseThread = thread(name = "kadb-reverse-accept") { + while (!Thread.currentThread().isInterrupted) { + try { + val openMessage = messageQueue.take(REVERSE_LISTENER_ID, AdbProtocol.CMD_OPEN) + handleIncomingOpen(openMessage) + } catch (_: InterruptedException) { + return@thread + } catch (t: Throwable) { + if (!Thread.currentThread().isInterrupted) { + log { "reverse bridge loop error: ${t.message}" } + } + } + } + } + } + + private fun handleIncomingOpen(message: AdbMessage) { + val destination = extractOpenDestination(message) ?: run { + adbWriter.writeClose(0, message.arg0) + return + } + + val target = parseReverseTcpTarget(destination) ?: run { + adbWriter.writeClose(0, message.arg0) + return + } + + val localSocket = runCatching { Socket(target.host, target.port) }.getOrElse { + adbWriter.writeClose(0, message.arg0) + return + } + + val localId = newId() + messageQueue.startListening(localId) + adbWriter.writeOkay(localId, message.arg0) + + val adbStream = AdbStream(messageQueue, adbWriter, maxPayloadSize, localId, message.arg0) + + val socketSource = localSocket.getInputStream().source() + val socketSink = localSocket.getOutputStream().sink().buffer() + + val readerThread = thread(name = "kadb-reverse-local-to-device") { + try { + bridge(socketSource, adbStream.sink) + } finally { + reverseSessionThreads.remove(Thread.currentThread()) + } + } + reverseSessionThreads.add(readerThread) + + val writerThread = thread(name = "kadb-reverse-device-to-local") { + try { + bridge(adbStream.source, socketSink) + } finally { + runCatching { adbStream.close() } + runCatching { localSocket.close() } + readerThread.interrupt() + reverseSessionThreads.remove(Thread.currentThread()) + } + } + reverseSessionThreads.add(writerThread) + } + + private fun extractOpenDestination(message: AdbMessage): String? { + if (message.payloadLength <= 0) return null + val bytes = message.payload + val endExclusive = if (bytes[message.payloadLength - 1].toInt() == 0) { + message.payloadLength - 1 + } else { + message.payloadLength + } + if (endExclusive <= 0) return null + return String(bytes, 0, endExclusive) + } + + private fun bridge(source: Source, sink: BufferedSink) { + while (!Thread.currentThread().isInterrupted) { + val read = runCatching { source.read(sink.buffer, 256) }.getOrElse { return } + if (read < 0) return + runCatching { sink.flush() }.getOrElse { return } + } + } + companion object { + private const val REVERSE_LISTENER_ID = 0 + suspend fun connect( host: String, port: Int, @@ -182,6 +295,35 @@ internal class AdbConnection internal constructor( } } +internal data class ReverseTcpTarget( + val host: String, + val port: Int, +) + +internal fun parseReverseTcpTarget(destination: String): ReverseTcpTarget? { + if (!destination.startsWith("tcp:")) return null + val raw = destination.removePrefix("tcp:") + if (raw.isBlank()) return null + + val segments = raw.split(':') + return when (segments.size) { + 1 -> { + val port = segments[0].toIntOrNull() ?: return null + if (port !in 1..65535) return null + ReverseTcpTarget(host = "127.0.0.1", port = port) + } + + 2 -> { + val host = segments[0].ifBlank { return null } + val port = segments[1].toIntOrNull() ?: return null + if (port !in 1..65535) return null + ReverseTcpTarget(host = host, port = port) + } + + else -> null + } +} + /*** ADB RSA Public Key Transformation Section ***/ private const val KEY_LENGTH_BITS = 2048 diff --git a/kadb/src/commonMain/kotlin/com/flyfishxu/kadb/reverse/ReverseCore.kt b/kadb/src/commonMain/kotlin/com/flyfishxu/kadb/reverse/ReverseCore.kt new file mode 100644 index 0000000..d8c586b --- /dev/null +++ b/kadb/src/commonMain/kotlin/com/flyfishxu/kadb/reverse/ReverseCore.kt @@ -0,0 +1,117 @@ +package com.flyfishxu.kadb.reverse + +import com.flyfishxu.kadb.Kadb + +data class AdbReverseRule( + val device: String, + val host: String, +) + +internal abstract class BaseReverse { + protected fun createRule(device: String, host: String, noRebind: Boolean = false) { + execute(buildReverseForwardDestination(device, host, noRebind)) + } + + protected fun removeRule(device: String) { + execute(buildReverseKillDestination(device)) + } + + protected fun removeAllRules() { + execute(buildReverseKillAllDestination()) + } + + protected fun listRules(): List { + return parseReverseListOutput(execute(buildReverseListDestination())) + } + + protected abstract fun execute(destination: String): String +} + +internal class AdbReverse( + private val kadb: Kadb, +) : BaseReverse() { + fun create(device: String, host: String, noRebind: Boolean = false) { + createRule(device, host, noRebind) + } + + fun remove(device: String) { + removeRule(device) + } + + fun removeAll() { + removeAllRules() + } + + fun list(): List { + return listRules() + } + + override fun execute(destination: String): String { + return kadb.open(destination).use { stream -> + val output = stream.source.readUtf8() + val response = parseSmartSocketResponse(output) + check(response.status != SmartSocketStatus.FAIL) { "Reverse command failed: ${response.payload}" } + response.payload + } + } +} + +internal fun buildReverseForwardDestination(device: String, host: String, noRebind: Boolean = false): String { + require(device.isNotBlank()) { "device must not be blank" } + require(host.isNotBlank()) { "host must not be blank" } + val prefix = if (noRebind) "reverse:forward:norebind:" else "reverse:forward:" + return "$prefix$device;$host" +} + +internal fun buildReverseKillDestination(device: String): String { + require(device.isNotBlank()) { "device must not be blank" } + return "reverse:killforward:$device" +} + +internal fun buildReverseKillAllDestination(): String = "reverse:killforward-all" + +internal fun buildReverseListDestination(): String = "reverse:list-forward" + +internal fun parseReverseListOutput(output: String): List { + return output + .lineSequence() + .map { it.trim() } + .filter { it.isNotBlank() } + .mapNotNull { line -> + val fields = line.split(Regex("\\s+")) + when { + fields.size >= 3 -> AdbReverseRule(device = fields[fields.lastIndex - 1], host = fields.last()) + fields.size == 2 -> AdbReverseRule(device = fields[0], host = fields[1]) + else -> null + } + } + .toList() +} + +internal enum class SmartSocketStatus { + OKAY, + FAIL, + UNKNOWN, +} + +internal data class SmartSocketResponse( + val status: SmartSocketStatus, + val payload: String, +) + +internal fun parseSmartSocketResponse(raw: String): SmartSocketResponse { + return when { + raw.startsWith("OKAY") -> SmartSocketResponse(SmartSocketStatus.OKAY, decodeProtocolStringOrRaw(raw.removePrefix("OKAY"))) + raw.startsWith("FAIL") -> SmartSocketResponse(SmartSocketStatus.FAIL, decodeProtocolStringOrRaw(raw.removePrefix("FAIL"))) + else -> SmartSocketResponse(SmartSocketStatus.UNKNOWN, decodeProtocolStringOrRaw(raw)) + } +} + +private fun decodeProtocolStringOrRaw(content: String): String { + if (content.length < 4) return content + val header = content.substring(0, 4) + val length = header.toIntOrNull(16) ?: return content + val payloadEnd = 4 + length + if (content.length < payloadEnd) return content + return content.substring(4, payloadEnd) +}