Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions kadb/src/commonMain/kotlin/com/flyfishxu/kadb/Kadb.kt
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ import com.flyfishxu.kadb.cert.platform.defaultDeviceName
import com.flyfishxu.kadb.core.AdbConnection
import com.flyfishxu.kadb.forwarding.TcpForwarder
import com.flyfishxu.kadb.pair.PairingConnectionCtx
import com.flyfishxu.kadb.reverse.AdbReverse
import com.flyfishxu.kadb.reverse.AdbReverseRule
import com.flyfishxu.kadb.shell.AdbShellResponse
import com.flyfishxu.kadb.shell.AdbShellStream
import com.flyfishxu.kadb.stream.AdbStream
Expand All @@ -27,6 +29,7 @@ class Kadb(
) : AutoCloseable {

private var connection: Pair<AdbConnection, TransportChannel>? = null
private val adbReverse by lazy { AdbReverse(this) }

fun connectionCheck(): Boolean = connection?.second?.isOpen == true

Expand Down Expand Up @@ -196,6 +199,26 @@ class Kadb(
return forwarder
}

fun reverseForward(device: String, host: String, noRebind: Boolean = false) {
adbReverse.create(device, host, noRebind)
}

fun reverseForward(devicePort: Int, hostPort: Int, noRebind: Boolean = false) {
reverseForward("tcp:$devicePort", "tcp:$hostPort", noRebind)
}

fun reverseKillForward(device: String) {
adbReverse.remove(device)
}

fun reverseKillAllForwards() {
adbReverse.removeAll()
}

fun reverseListForwards(): List<AdbReverseRule> {
return adbReverse.list()
}

private fun restartAdb(destination: String): String {
this.open(destination).use { stream ->
return stream.source.readUntil('\n'.code.toByte()).readString(Charsets.UTF_8)
Expand Down
144 changes: 143 additions & 1 deletion kadb/src/commonMain/kotlin/com/flyfishxu/kadb/core/AdbConnection.kt
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package com.flyfishxu.kadb.core

import com.flyfishxu.kadb.cert.AdbKeyPair
import com.flyfishxu.kadb.cert.platform.defaultDeviceName
import com.flyfishxu.kadb.debug.log
import com.flyfishxu.kadb.pair.SslUtils
import com.flyfishxu.kadb.queue.AdbMessageQueue
import com.flyfishxu.kadb.stream.AdbStream
Expand All @@ -19,11 +20,19 @@ import java.nio.ByteBuffer
import java.nio.ByteOrder
import java.security.interfaces.RSAPublicKey
import java.util.*
import java.util.concurrent.ConcurrentLinkedQueue
import java.util.concurrent.TimeUnit
import javax.net.ssl.SSLProtocolException
import kotlin.Throws
import kotlin.io.encoding.Base64
import kotlin.io.encoding.ExperimentalEncodingApi
import kotlin.concurrent.thread
import okio.BufferedSink
import okio.Source
import okio.buffer
import okio.sink
import okio.source
import java.net.Socket

internal class AdbConnection internal constructor(
adbReader: AdbReader,
Expand All @@ -36,6 +45,12 @@ internal class AdbConnection internal constructor(

private val random = Random()
private val messageQueue = AdbMessageQueue(adbReader)
private var reverseThread: Thread? = null
private val reverseSessionThreads = ConcurrentLinkedQueue<Thread>()

init {
startReverseBridgeLoop()
}

@Throws(IOException::class)
fun open(destination: String): AdbStream {
Expand All @@ -58,7 +73,11 @@ internal class AdbConnection internal constructor(
}

private fun newId(): Int {
return random.nextInt()
var id: Int
do {
id = random.nextInt()
} while (id == 0)
return id
}

@TestOnly
Expand All @@ -68,14 +87,108 @@ internal class AdbConnection internal constructor(

override fun close() {
try {
reverseThread?.interrupt()
reverseThread = null
while (true) {
val t = reverseSessionThreads.poll() ?: break
t.interrupt()
}
runCatching { messageQueue.stopListening(REVERSE_LISTENER_ID) }
messageQueue.close()
adbWriter.close()
closeable?.close()
} catch (_: Throwable) {
}
}

private fun startReverseBridgeLoop() {
messageQueue.startListening(REVERSE_LISTENER_ID)
reverseThread = thread(name = "kadb-reverse-accept") {
while (!Thread.currentThread().isInterrupted) {
try {
val openMessage = messageQueue.take(REVERSE_LISTENER_ID, AdbProtocol.CMD_OPEN)
handleIncomingOpen(openMessage)
} catch (_: InterruptedException) {
return@thread
} catch (t: Throwable) {
if (!Thread.currentThread().isInterrupted) {
log { "reverse bridge loop error: ${t.message}" }
}
}
}
}
}

private fun handleIncomingOpen(message: AdbMessage) {
val destination = extractOpenDestination(message) ?: run {
adbWriter.writeClose(0, message.arg0)
return
}

val target = parseReverseTcpTarget(destination) ?: run {
adbWriter.writeClose(0, message.arg0)
return
}

val localSocket = runCatching { Socket(target.host, target.port) }.getOrElse {
adbWriter.writeClose(0, message.arg0)
return
}

val localId = newId()
messageQueue.startListening(localId)
adbWriter.writeOkay(localId, message.arg0)

val adbStream = AdbStream(messageQueue, adbWriter, maxPayloadSize, localId, message.arg0)

val socketSource = localSocket.getInputStream().source()
val socketSink = localSocket.getOutputStream().sink().buffer()

val readerThread = thread(name = "kadb-reverse-local-to-device") {
try {
bridge(socketSource, adbStream.sink)
} finally {
reverseSessionThreads.remove(Thread.currentThread())
}
}
reverseSessionThreads.add(readerThread)

val writerThread = thread(name = "kadb-reverse-device-to-local") {
try {
bridge(adbStream.source, socketSink)
} finally {
runCatching { adbStream.close() }
runCatching { localSocket.close() }
readerThread.interrupt()
reverseSessionThreads.remove(Thread.currentThread())
}
}
reverseSessionThreads.add(writerThread)
}

private fun extractOpenDestination(message: AdbMessage): String? {
if (message.payloadLength <= 0) return null
val bytes = message.payload
val endExclusive = if (bytes[message.payloadLength - 1].toInt() == 0) {
message.payloadLength - 1
} else {
message.payloadLength
}
if (endExclusive <= 0) return null
return String(bytes, 0, endExclusive)
}

private fun bridge(source: Source, sink: BufferedSink) {
while (!Thread.currentThread().isInterrupted) {
val read = runCatching { source.read(sink.buffer, 256) }.getOrElse { return }
if (read < 0) return
runCatching { sink.flush() }.getOrElse { return }
}
}

companion object {
private const val REVERSE_LISTENER_ID = 0

suspend fun connect(
host: String,
port: Int,
Expand Down Expand Up @@ -182,6 +295,35 @@ internal class AdbConnection internal constructor(
}
}

internal data class ReverseTcpTarget(
val host: String,
val port: Int,
)

internal fun parseReverseTcpTarget(destination: String): ReverseTcpTarget? {
if (!destination.startsWith("tcp:")) return null
val raw = destination.removePrefix("tcp:")
if (raw.isBlank()) return null

val segments = raw.split(':')
return when (segments.size) {
1 -> {
val port = segments[0].toIntOrNull() ?: return null
if (port !in 1..65535) return null
ReverseTcpTarget(host = "127.0.0.1", port = port)
}

2 -> {
val host = segments[0].ifBlank { return null }
val port = segments[1].toIntOrNull() ?: return null
if (port !in 1..65535) return null
ReverseTcpTarget(host = host, port = port)
}

else -> null
}
}


/*** ADB RSA Public Key Transformation Section ***/
private const val KEY_LENGTH_BITS = 2048
Expand Down
117 changes: 117 additions & 0 deletions kadb/src/commonMain/kotlin/com/flyfishxu/kadb/reverse/ReverseCore.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
package com.flyfishxu.kadb.reverse

import com.flyfishxu.kadb.Kadb

data class AdbReverseRule(
val device: String,
val host: String,
)

internal abstract class BaseReverse {
protected fun createRule(device: String, host: String, noRebind: Boolean = false) {
execute(buildReverseForwardDestination(device, host, noRebind))
}

protected fun removeRule(device: String) {
execute(buildReverseKillDestination(device))
}

protected fun removeAllRules() {
execute(buildReverseKillAllDestination())
}

protected fun listRules(): List<AdbReverseRule> {
return parseReverseListOutput(execute(buildReverseListDestination()))
}

protected abstract fun execute(destination: String): String
}

internal class AdbReverse(
private val kadb: Kadb,
) : BaseReverse() {
fun create(device: String, host: String, noRebind: Boolean = false) {
createRule(device, host, noRebind)
}

fun remove(device: String) {
removeRule(device)
}

fun removeAll() {
removeAllRules()
}

fun list(): List<AdbReverseRule> {
return listRules()
}

override fun execute(destination: String): String {
return kadb.open(destination).use { stream ->
val output = stream.source.readUtf8()
val response = parseSmartSocketResponse(output)
check(response.status != SmartSocketStatus.FAIL) { "Reverse command failed: ${response.payload}" }
response.payload
}
}
}

internal fun buildReverseForwardDestination(device: String, host: String, noRebind: Boolean = false): String {
require(device.isNotBlank()) { "device must not be blank" }
require(host.isNotBlank()) { "host must not be blank" }
val prefix = if (noRebind) "reverse:forward:norebind:" else "reverse:forward:"
return "$prefix$device;$host"
}

internal fun buildReverseKillDestination(device: String): String {
require(device.isNotBlank()) { "device must not be blank" }
return "reverse:killforward:$device"
}

internal fun buildReverseKillAllDestination(): String = "reverse:killforward-all"

internal fun buildReverseListDestination(): String = "reverse:list-forward"

internal fun parseReverseListOutput(output: String): List<AdbReverseRule> {
return output
.lineSequence()
.map { it.trim() }
.filter { it.isNotBlank() }
.mapNotNull { line ->
val fields = line.split(Regex("\\s+"))
when {
fields.size >= 3 -> AdbReverseRule(device = fields[fields.lastIndex - 1], host = fields.last())
fields.size == 2 -> AdbReverseRule(device = fields[0], host = fields[1])
else -> null
}
}
.toList()
}

internal enum class SmartSocketStatus {
OKAY,
FAIL,
UNKNOWN,
}

internal data class SmartSocketResponse(
val status: SmartSocketStatus,
val payload: String,
)

internal fun parseSmartSocketResponse(raw: String): SmartSocketResponse {
return when {
raw.startsWith("OKAY") -> SmartSocketResponse(SmartSocketStatus.OKAY, decodeProtocolStringOrRaw(raw.removePrefix("OKAY")))
raw.startsWith("FAIL") -> SmartSocketResponse(SmartSocketStatus.FAIL, decodeProtocolStringOrRaw(raw.removePrefix("FAIL")))
else -> SmartSocketResponse(SmartSocketStatus.UNKNOWN, decodeProtocolStringOrRaw(raw))
}
}

private fun decodeProtocolStringOrRaw(content: String): String {
if (content.length < 4) return content
val header = content.substring(0, 4)
val length = header.toIntOrNull(16) ?: return content
val payloadEnd = 4 + length
if (content.length < payloadEnd) return content
return content.substring(4, payloadEnd)
}