diff --git a/services/CHANGELOG.md b/services/CHANGELOG.md index 9d62f882d..8edd877ae 100644 --- a/services/CHANGELOG.md +++ b/services/CHANGELOG.md @@ -11,6 +11,12 @@ **New Features** +* Adding a LocalSocket-based protocol for the ShellExecutor to talk to the + ShellMain. This obsoletes SpeakEasy; if androidx.test.services is killed + (e.g. by the low memory killer) between the start of the app_process that + invokes LocalSocketShellMain and the start of the test, the test is still able + to talk to LocalSocketShellMain. + **Breaking Changes** **API Changes** diff --git a/services/shellexecutor/java/androidx/test/services/shellexecutor/BUILD b/services/shellexecutor/java/androidx/test/services/shellexecutor/BUILD index c1257b98d..897bc547a 100644 --- a/services/shellexecutor/java/androidx/test/services/shellexecutor/BUILD +++ b/services/shellexecutor/java/androidx/test/services/shellexecutor/BUILD @@ -29,6 +29,32 @@ kt_android_library( ], ) +proto_library( + name = "local_socket_protocol_pb", + srcs = ["local_socket_protocol.proto"], +) + +java_lite_proto_library( + name = "local_socket_protocol_pb_java_proto_lite", + visibility = [ + "//services/shellexecutor/javatests/androidx/test/services/shellexecutor:__subpackages__", + ], + deps = [":local_socket_protocol_pb"], +) + +kt_android_library( + name = "local_socket_protocol", + srcs = ["LocalSocketProtocol.kt"], + visibility = [ + "//services/shellexecutor/javatests/androidx/test/services/shellexecutor:__subpackages__", + ], + deps = [ + ":local_socket_protocol_pb_java_proto_lite", + "@com_google_protobuf//:protobuf_javalite", + "@maven//:org_jetbrains_kotlinx_kotlinx_coroutines_core", + ], +) + kt_android_library( name = "exec_server", srcs = [ @@ -62,6 +88,7 @@ kt_android_library( "ShellCommand.java", "ShellCommandClient.java", "ShellCommandFileObserverClient.kt", + "ShellCommandLocalSocketClient.kt", "ShellExecSharedConstants.java", "ShellExecutor.java", "ShellExecutorFactory.java", @@ -73,6 +100,8 @@ kt_android_library( deps = [ ":coroutine_file_observer", ":file_observer_protocol", + ":local_socket_protocol", + ":local_socket_protocol_pb_java_proto_lite", "//services/speakeasy/java/androidx/test/services/speakeasy:protocol", "//services/speakeasy/java/androidx/test/services/speakeasy/client", "//services/speakeasy/java/androidx/test/services/speakeasy/client:tool_connection", diff --git a/services/shellexecutor/java/androidx/test/services/shellexecutor/LocalSocketProtocol.kt b/services/shellexecutor/java/androidx/test/services/shellexecutor/LocalSocketProtocol.kt new file mode 100644 index 000000000..4949d232d --- /dev/null +++ b/services/shellexecutor/java/androidx/test/services/shellexecutor/LocalSocketProtocol.kt @@ -0,0 +1,152 @@ +/* + * Copyright (C) 2024 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package androidx.test.services.shellexecutor + +import android.net.LocalSocket +import android.net.LocalSocketAddress +import android.util.Log +import androidx.test.services.shellexecutor.LocalSocketProtocolProto.RunCommandRequest +import androidx.test.services.shellexecutor.LocalSocketProtocolProto.RunCommandResponse +import com.google.protobuf.ByteString +import java.io.IOException +import java.net.URLDecoder +import java.net.URLEncoder +import kotlin.time.Duration + +/** + * Protocol for ShellCommandLocalSocketClient to talk to ShellCommandLocalSocketExecutorServer. + * + * Since androidx.test.services already includes the protobuf runtime, we aren't paying much extra + * for adding some more protos to ship back and forth, which is vastly easier to deal with than + * PersistableBundles (which don't even support ByteArray types). + * + * A conversation consists of a single RunCommandRequest from the client followed by a stream of + * RunCommandResponses from the server; the final response has an exit code. + */ +object LocalSocketProtocol { + /** + * Composes a RunCommandRequest and sends it over the LocalSocket. + * + * @param secret The secret to authenticate the request. + * @param argv The argv of the command line to run. + * @param env The environment variables to provide to the process. + * @param timeout The timeout for the command; infinite or nonpositive values mean no timeout. + */ + fun LocalSocket.sendRequest( + secret: String, + argv: List, + env: Map? = null, + timeout: Duration, + ) { + val builder = RunCommandRequest.newBuilder().setSecret(secret).addAllArgv(argv) + env?.forEach { (k, v) -> builder.putEnvironment(k, v) } + if (timeout.isInfinite() || timeout.isNegative() || timeout == Duration.ZERO) { + builder.setTimeoutMs(0) // <= 0 means no timeout + } else { + builder.setTimeoutMs(timeout.inWholeMilliseconds) + } + builder.build().writeDelimitedTo(outputStream) + } + + /** Reads a RunCommandRequest from the LocalSocket. */ + fun LocalSocket.readRequest(): RunCommandRequest { + return RunCommandRequest.parseDelimitedFrom(inputStream)!! + } + + /** Composes a RunCommandResponse and sends it over the LocalSocket. */ + fun LocalSocket.sendResponse( + buffer: ByteArray? = null, + size: Int = 0, + exitCode: Int? = null, + ): Boolean { + val builder = RunCommandResponse.newBuilder() + buffer?.let { + val bufferSize = if (size > 0) size else it.size + builder.buffer = ByteString.copyFrom(it, 0, bufferSize) + } + // Since we're currently stuck on a version of protobuf where we don't have hasExitCode(), we + // use a magic value to indicate that exitCode is not set. When we upgrade to a newer version + // of protobuf, we can obsolete this. + if (exitCode != null) { + builder.exitCode = exitCode + } else { + builder.exitCode = HAS_NOT_EXITED + } + + try { + builder.build().writeDelimitedTo(outputStream) + } catch (x: IOException) { + // Sadly, the only way to discover that the client cut the connection is an exception that + // can only be distinguished by its text. + if (x.message.equals("Broken pipe")) { + Log.i(TAG, "LocalSocket stream closed early") + } else { + Log.w(TAG, "LocalSocket write failed", x) + } + return false + } + return true + } + + /** Reads a RunCommandResponse from the LocalSocket. */ + fun LocalSocket.readResponse(): RunCommandResponse? { + return RunCommandResponse.parseDelimitedFrom(inputStream) + } + + /** + * Is this the end of the stream? + * + * Once we upgrade to a newer version of protobuf, we can switch to hasExitCode(). + */ + fun RunCommandResponse.hasExited() = exitCode != HAS_NOT_EXITED + + /** + * Builds a "binder key", given the server address and secret. (We are not actually using a Binder + * here, but the ShellExecutor interface calls the secret for connecting client to server a + * "binder key", so we stick with that nomenclature.) Binder keys should be opaque outside + * this directory. + * + * The address can contain spaces, and since it gets passed through a command line, we need to + * encode it so it doesn't get split by argv. java.net.URLEncoder is conveniently available on all + * SDK versions. + */ + @JvmStatic + fun LocalSocketAddress.asBinderKey(secret: String) = buildString { + append(":") + append(URLEncoder.encode(name, "UTF-8")) // Will convert any : to %3A + append(":") + append(URLEncoder.encode(secret, "UTF-8")) + append(":") + } + + /** Extracts the address from a binder key. */ + @JvmStatic + fun addressFromBinderKey(binderKey: String) = + LocalSocketAddress(URLDecoder.decode(binderKey.split(":")[1], "UTF-8")) + + /** Extracts the secret from a binder key. */ + @JvmStatic + fun secretFromBinderKey(binderKey: String) = URLDecoder.decode(binderKey.split(":")[2], "UTF-8") + + /** Is this a valid binder key? */ + @JvmStatic + fun isBinderKey(maybeKey: String) = + maybeKey.startsWith(':') && maybeKey.endsWith(':') && maybeKey.split(":").size == 4 + + const val TAG = "LocalSocketProtocol" + private const val HAS_NOT_EXITED = 0xCA7F00D +} diff --git a/services/shellexecutor/java/androidx/test/services/shellexecutor/ShellCommandLocalSocketClient.kt b/services/shellexecutor/java/androidx/test/services/shellexecutor/ShellCommandLocalSocketClient.kt new file mode 100644 index 000000000..103a5b190 --- /dev/null +++ b/services/shellexecutor/java/androidx/test/services/shellexecutor/ShellCommandLocalSocketClient.kt @@ -0,0 +1,171 @@ +/* + * Copyright (C) 2024 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package androidx.test.services.shellexecutor + +import android.net.LocalSocket +import android.net.LocalSocketAddress +import android.os.Build +import android.util.Log +import androidx.test.services.shellexecutor.LocalSocketProtocol.addressFromBinderKey +import androidx.test.services.shellexecutor.LocalSocketProtocol.hasExited +import androidx.test.services.shellexecutor.LocalSocketProtocol.readResponse +import androidx.test.services.shellexecutor.LocalSocketProtocol.secretFromBinderKey +import androidx.test.services.shellexecutor.LocalSocketProtocol.sendRequest +import java.io.IOException +import java.io.InputStream +import java.io.PipedInputStream +import java.io.PipedOutputStream +import java.util.concurrent.Executors +import kotlin.time.Duration +import kotlin.time.Duration.Companion.milliseconds +import kotlin.time.measureTime +import kotlin.time.toKotlinDuration +import kotlinx.coroutines.CoroutineScope +import kotlinx.coroutines.asCoroutineDispatcher +import kotlinx.coroutines.delay +import kotlinx.coroutines.launch +import kotlinx.coroutines.runBlocking +import kotlinx.coroutines.runInterruptible +import kotlinx.coroutines.withTimeout + +/** + * Client that sends requests to the ShellCommandLocalSocketExecutorServer. + * + * This client is designed to be callable from Java. + */ +class ShellCommandLocalSocketClient(binderKey: String) { + private val address: LocalSocketAddress = addressFromBinderKey(binderKey) + private val secret: String = secretFromBinderKey(binderKey) + private lateinit var socket: LocalSocket + + /** + * Composes a request and sends it to the server, and streams the resulting output. + * @param command The command to run. + * @param parameters The parameters to the command. command + parameters = argv + * @param shellEnv The environment variables to provide to the process. + * @param executeThroughShell Whether to execute the command through a shell, making the argv + * "sh" "-c" "command parameters". + * @param timeout The timeout for the command; infinite or nonpositive values mean no timeout. + * @return An InputStream that can be used to read the output of the command. + */ + @kotlin.time.ExperimentalTime + fun request( + command: String?, + parameters: List?, + shellEnv: Map?, + executeThroughShell: Boolean, + timeout: Duration, + ): InputStream { + if (command == null || command.isEmpty()) { + throw IllegalArgumentException("Null or empty command") + } + + lateinit var result: InputStream + + // The call to runBlocking causes Android to emit "art: Note: end time exceeds epoch:". This is + // in InitTimeSpec in runtime/utils.cc. I don't see a way to invoke it in such a way that it + // doesn't clutter the logcat. + runBlocking(scope.coroutineContext) { + withTimeout(timeout) { + runInterruptible { + socket = LocalSocket(LocalSocket.SOCKET_STREAM) + // While there *is* a timeout option on connect(), in the Android source, it throws + // UnsupportedOperationException! So we leave the timeout up to withTimeout + + // runInterruptible. Capture the time taken to connect so we can subtract it from the + // overall timeout. (Calling socket.setSoTimeout() before connect() throws IOException + // "socket not created".) + val connectTime = measureTime { socket.connect(address) } + + val argv = mutableListOf() + if (executeThroughShell) { + argv.addAll(listOf("sh", "-c")) + argv.add((listOf(command) + (parameters ?: emptyList())).joinToString(" ")) + } else { + argv.add(command) + parameters?.let { argv.addAll(it) } + } + + socket.sendRequest(secret, argv, shellEnv, timeout - connectTime) + socket.shutdownOutput() + + // We read responses off the socket, write buffers to the pipe, and close the pipe when we + // get an exit code. The existing ShellExecutor API doesn't provide for *returning* that + // exit code, but it's useful as a way to know when to close the stream. By using the pipe + // as an intermediary, we can respond to exceptions sensibly. + val upstream = PipedOutputStream() + val downstream = PipedInputStream(upstream) + + scope.launch { + try { + socket.inputStream.use { + while (true) { + val response = socket.readResponse() + if (response == null) break // EOF + if (response.buffer.size() > 0) response.buffer.writeTo(upstream) + if (response.hasExited()) { + Log.i(TAG, "Process ${argv[0]} exited with code ${response.exitCode}") + break + } + } + } + } catch (x: IOException) { + if (x.isPipeClosed()) { + Log.i(TAG, "LocalSocket relay for ${argv[0]} closed early") + } else { + Log.w(TAG, "LocalSocket relay for ${argv[0]} failed", x) + } + } finally { + upstream.flush() + upstream.close() + } + } + + result = downstream + } + } + } + return result + } + + /** Java-friendly wrapper for the above. */ + @kotlin.time.ExperimentalTime + fun request( + command: String?, + parameters: List?, + shellEnv: Map?, + executeThroughShell: Boolean, + timeout: java.time.Duration, + ): InputStream = + request(command, parameters, shellEnv, executeThroughShell, timeout.toKotlinDuration()) + + private companion object { + private const val TAG = "SCLSClient" // up to 23 characters + + // Keep this around for all clients; if you create a new one with every object, you can wind up + // running out of threads. + private val scope = CoroutineScope(Executors.newCachedThreadPool().asCoroutineDispatcher()) + } +} + +// Sadly, the only way to distinguish the downstream pipe being closed is the text +// of the exception thrown when you try to write to it. Which varies by API level. +private fun IOException.isPipeClosed() = + if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.N) { + message.equals("Pipe closed") + } else { + message.equals("Pipe is closed") + } diff --git a/services/shellexecutor/java/androidx/test/services/shellexecutor/local_socket_protocol.proto b/services/shellexecutor/java/androidx/test/services/shellexecutor/local_socket_protocol.proto new file mode 100644 index 000000000..326befeea --- /dev/null +++ b/services/shellexecutor/java/androidx/test/services/shellexecutor/local_socket_protocol.proto @@ -0,0 +1,50 @@ +// +// Copyright (C) 2024 The Android Open Source Project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +syntax = "proto3"; + +package androidx.test.services.storage; + +option java_package = "androidx.test.services.shellexecutor"; +option java_outer_classname = 'LocalSocketProtocolProto'; + +// Message sent from client to server to start a process. +message RunCommandRequest { + // Secret to authenticate the request. + string secret = 1; + + // argv of the command line to run. + repeated string argv = 2; + + // Environment varialbes to provide. + map environment = 3; + + // Timeout for the command. Any value <= 0 is treated as "forever". + int64 timeout_ms = 4; +} + +// Multiple responses can be streamed back to the client. The one that has an +// exit code indicates the end of the stream. +message RunCommandResponse { + // A buffer of the command's output (stdout and stderr combined by specifying + // redirectErrorStream(true) on ProcessBuilder). + bytes buffer = 1; + + // The exit code of the command. While we're stuck on proto3, the magic value + // 0xCA7F00D indicates that the command is still running; once we can move to + // a newer version where we can test hasExitCode(), we will remove the magic + // value. + int32 exit_code = 2; +} diff --git a/services/shellexecutor/javatests/androidx/test/services/shellexecutor/BUILD b/services/shellexecutor/javatests/androidx/test/services/shellexecutor/BUILD index 0655cd59d..0b69391ab 100644 --- a/services/shellexecutor/javatests/androidx/test/services/shellexecutor/BUILD +++ b/services/shellexecutor/javatests/androidx/test/services/shellexecutor/BUILD @@ -79,6 +79,22 @@ axt_android_library_test( ], ) +axt_android_library_test( + name = "ShellCommandLocalSocketClientTest", + srcs = [ + "ShellCommandLocalSocketClientTest.kt", + ], + deps = [ + "//runner/monitor", + "//services/shellexecutor:exec_client", + "//services/shellexecutor/java/androidx/test/services/shellexecutor:local_socket_protocol", + "//services/shellexecutor/java/androidx/test/services/shellexecutor:local_socket_protocol_pb_java_proto_lite", + "@maven//:com_google_truth_truth", + "@maven//:junit_junit", + "@maven//:org_jetbrains_kotlinx_kotlinx_coroutines_android", + ], +) + axt_android_library_test( name = "ShellCommandFileObserverExecutorServerTest", srcs = [ diff --git a/services/shellexecutor/javatests/androidx/test/services/shellexecutor/ShellCommandLocalSocketClientTest.kt b/services/shellexecutor/javatests/androidx/test/services/shellexecutor/ShellCommandLocalSocketClientTest.kt new file mode 100644 index 000000000..018c355a9 --- /dev/null +++ b/services/shellexecutor/javatests/androidx/test/services/shellexecutor/ShellCommandLocalSocketClientTest.kt @@ -0,0 +1,115 @@ +/* + * Copyright (C) 2024 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package androidx.test.services.shellexecutor + +import android.net.LocalServerSocket +import android.net.LocalSocketAddress +import androidx.test.services.shellexecutor.LocalSocketProtocol.addressFromBinderKey +import androidx.test.services.shellexecutor.LocalSocketProtocol.asBinderKey +import androidx.test.services.shellexecutor.LocalSocketProtocol.readRequest +import androidx.test.services.shellexecutor.LocalSocketProtocol.secretFromBinderKey +import androidx.test.services.shellexecutor.LocalSocketProtocolProto.RunCommandRequest +import com.google.common.truth.Truth.assertThat +import kotlin.time.Duration.Companion.seconds +import kotlinx.coroutines.async +import kotlinx.coroutines.runBlocking +import org.junit.Before +import org.junit.Test +import org.junit.runner.RunWith +import org.junit.runners.JUnit4 + +@RunWith(JUnit4::class) +class ShellCommandLocalSocketClientTest { + + @Before fun setUp() {} + + @Test + fun binderkey_success() { + val address = LocalSocketAddress("binderkey_success 12345") + val binderKey = address.asBinderKey(SECRET) + assertThat(addressFromBinderKey(binderKey).name).isEqualTo(address.name) + assertThat(addressFromBinderKey(binderKey).namespace).isEqualTo(address.namespace) + assertThat(secretFromBinderKey(binderKey)).isEqualTo(SECRET) + } + + @Test + fun request_regular() { + val server = LocalServerSocket("request_regular") + val client = ShellCommandLocalSocketClient(server.localSocketAddress.asBinderKey(SECRET)) + + val request: RunCommandRequest + + runBlocking { + val result = async { + val socket = server.accept() + socket.readRequest() + } + + client.request( + "foo", + listOf("bar", "baz"), + mapOf("quem" to "quux", "potrzebie" to "furshlugginer"), + executeThroughShell = false, + timeout = 1.seconds, + ) + request = result.await() + } + + assertThat(request.secret).isEqualTo(SECRET) + assertThat(request.argvList).containsExactly("foo", "bar", "baz") + assertThat(request.environmentMap) + .containsExactlyEntriesIn(mapOf("quem" to "quux", "potrzebie" to "furshlugginer")) + // The overall timeout will have the connect time shaved off. This is usually quite low, but + // I've seen it as high as 61ms. + assertThat(request.timeoutMs).isGreaterThan(900) + } + + @Test + fun request_executeThroughShell() { + val server = LocalServerSocket("request_executeThroughShell") + val client = ShellCommandLocalSocketClient(server.localSocketAddress.asBinderKey(SECRET)) + + val request: RunCommandRequest + + runBlocking { + val result = async { + val socket = server.accept() + socket.readRequest() + } + + client.request( + "foo", + listOf("bar", "baz"), + mapOf("quem" to "quux", "potrzebie" to "furshlugginer"), + executeThroughShell = true, + timeout = 1.seconds, + ) + request = result.await() + } + + assertThat(request.secret).isEqualTo(SECRET) + assertThat(request.argvList).containsExactly("sh", "-c", "foo bar baz") + assertThat(request.environmentMap) + .containsExactlyEntriesIn(mapOf("quem" to "quux", "potrzebie" to "furshlugginer")) + // The overall timeout will have the connect time shaved off. + assertThat(request.timeoutMs).isGreaterThan(900) + } + + private companion object { + const val SECRET = "foo:bar" + } +}