From ffac0c4bdf37575dc14fad97e89a69aec78a6f71 Mon Sep 17 00:00:00 2001 From: Johannes Zottele Date: Thu, 31 Jul 2025 19:19:14 +0200 Subject: [PATCH 1/9] grpc-pb: Refactor subtyping hierarchy Signed-off-by: Johannes Zottele --- .../kotlin/kotlinx/rpc/grpc/pb/Message.kt | 21 +++++++++++++++ .../kotlin/kotlinx/rpc/grpc/pb/ProtosTest.kt | 26 ++++++++++--------- .../grpc-core/src/commonTest/proto/mini.proto | 8 ++++++ .../src/commonTest/proto/repeated.proto | 5 +--- .../protobuf/ModelToKotlinCommonGenerator.kt | 25 ++++++++++++------ .../kotlinx/rpc/protobuf/model/model.kt | 2 +- 6 files changed, 62 insertions(+), 25 deletions(-) create mode 100644 grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/pb/Message.kt create mode 100644 grpc/grpc-core/src/commonTest/proto/mini.proto diff --git a/grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/pb/Message.kt b/grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/pb/Message.kt new file mode 100644 index 00000000..5e64a4d0 --- /dev/null +++ b/grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/pb/Message.kt @@ -0,0 +1,21 @@ +/* + * Copyright 2023-2025 JetBrains s.r.o and contributors. Use of this source code is governed by the Apache 2.0 license. + */ + +package kotlinx.rpc.grpc.pb + +import kotlinx.rpc.internal.utils.InternalRpcApi + +@InternalRpcApi +public abstract class Message { + + @InternalRpcApi + public interface Companion { + + public fun decodeWith(decoder: WireDecoder): T + + } + + public abstract fun encodeWith(encoder: WireEncoder) + +} \ No newline at end of file diff --git a/grpc/grpc-core/src/commonTest/kotlin/kotlinx/rpc/grpc/pb/ProtosTest.kt b/grpc/grpc-core/src/commonTest/kotlin/kotlinx/rpc/grpc/pb/ProtosTest.kt index f12e932f..32dac983 100644 --- a/grpc/grpc-core/src/commonTest/kotlin/kotlinx/rpc/grpc/pb/ProtosTest.kt +++ b/grpc/grpc-core/src/commonTest/kotlin/kotlinx/rpc/grpc/pb/ProtosTest.kt @@ -11,19 +11,18 @@ import kotlin.test.assertEquals class ProtosTest { - private fun decodeEncode( + private fun decodeEncode( msg: T, - enc: T.(WireEncoder) -> Unit, - dec: (WireDecoder) -> T? - ): T? { + decoder: (WireDecoder) -> T, + ): T { val buffer = Buffer() val encoder = WireEncoder(buffer) - msg.enc(encoder) + msg.encodeWith(encoder) encoder.flush() return WireDecoder(buffer).use { - dec(it) + decoder(it) } } @@ -48,9 +47,12 @@ class ProtosTest { bytes = byteArrayOf(1, 2, 3) } - val decoded = decodeEncode(msg, { encodeWith(it) }, AllPrimitivesCommon::decodeWith) + val msgObj = msg as Message - assertEquals(msg.double, decoded?.double) + val decoded = decodeEncode(msgObj, AllPrimitivesCommonBuilder::decodeWith) + as AllPrimitivesCommon + + assertEquals(msg.double, decoded.double) } @Test @@ -61,11 +63,11 @@ class ProtosTest { listString = listOf("a", "b", "c") } - val decoded = decodeEncode(msg, { encodeWith(it) }, RepeatedCommon::decodeWith) + val decoded = decodeEncode(msg as Message, RepeatedCommonBuilder::decodeWith) as RepeatedCommonBuilder - assertEquals(msg.listInt32, decoded?.listInt32) - assertEquals(msg.listFixed32, decoded?.listFixed32) - assertEquals(msg.listString, decoded?.listString) + assertEquals(msg.listInt32, decoded.listInt32) + assertEquals(msg.listFixed32, decoded.listFixed32) + assertEquals(msg.listString, decoded.listString) } } diff --git a/grpc/grpc-core/src/commonTest/proto/mini.proto b/grpc/grpc-core/src/commonTest/proto/mini.proto new file mode 100644 index 00000000..8ef15f6d --- /dev/null +++ b/grpc/grpc-core/src/commonTest/proto/mini.proto @@ -0,0 +1,8 @@ +syntax = "proto3"; + +package kotlinx.rpc.grpc.test.common; + +message MiniMsg { + int32 MiniField = 1; +} + diff --git a/grpc/grpc-core/src/commonTest/proto/repeated.proto b/grpc/grpc-core/src/commonTest/proto/repeated.proto index 55d1a2a1..93c7ce58 100644 --- a/grpc/grpc-core/src/commonTest/proto/repeated.proto +++ b/grpc/grpc-core/src/commonTest/proto/repeated.proto @@ -6,8 +6,5 @@ message RepeatedCommon { repeated fixed32 listFixed32 = 1 [packed = true]; repeated int32 listInt32 = 2 [packed = false]; repeated string listString = 3; - - message InnerClass { - - } } + diff --git a/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/ModelToKotlinCommonGenerator.kt b/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/ModelToKotlinCommonGenerator.kt index fadeaebc..bc88e144 100644 --- a/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/ModelToKotlinCommonGenerator.kt +++ b/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/ModelToKotlinCommonGenerator.kt @@ -100,8 +100,6 @@ class ModelToKotlinCommonGenerator( fileDeclaration.messageDeclarations.forEach { generateMessageConstructor(it) - generateMessageDecoder(it) - generateMessageEncoder(it) } } @@ -140,10 +138,11 @@ class ModelToKotlinCommonGenerator( @Suppress("detekt.CyclomaticComplexMethod") private fun CodeGenerator.generateInternalMessage(declaration: MessageDeclaration) { + val builderClassName = "${declaration.name.simpleName}Builder" clazz( - name = "${declaration.name.simpleName}Builder", + name = builderClassName, declarationType = DeclarationType.Class, - superTypes = listOf(declaration.name.safeFullName()), + superTypes = listOf(declaration.name.safeFullName(), "Message()"), ) { declaration.fields().forEach { (fieldDeclaration, field) -> val value = when { @@ -168,6 +167,12 @@ class ModelToKotlinCommonGenerator( declaration.nestedDeclarations.forEach { nested -> generateInternalMessage(nested) } + + generateMessageEncoder(declaration) + + scope("companion object: Message.Companion<$builderClassName>") { + generateMessageDecoder(declaration) + } } } @@ -184,8 +189,8 @@ class ModelToKotlinCommonGenerator( private fun CodeGenerator.generateMessageDecoder(declaration: MessageDeclaration) = function( name = "decodeWith", args = "decoder: WireDecoder", - contextReceiver = "${declaration.name.safeFullName()}.Companion", - returnType = "${declaration.name.safeFullName()}?" + modifiers = "override", + returnType = declaration.name.simpleName + "Builder" ) { code("val msg = ${declaration.name.safeFullName("Builder")}()") whileBlock("!decoder.hadError()") { @@ -197,7 +202,7 @@ class ModelToKotlinCommonGenerator( } ifBranch( condition = "decoder.hadError()", - ifBlock = { code("return null") } + ifBlock = { code("error(\"Error during decoding of ${declaration.name.simpleName}\")") } ) // TODO: Make a lists immutable @@ -231,8 +236,12 @@ class ModelToKotlinCommonGenerator( private fun CodeGenerator.generateMessageEncoder(declaration: MessageDeclaration) = function( name = "encodeWith", args = "encoder: WireEncoder", - contextReceiver = declaration.name.safeFullName(), + modifiers = "override" ) { + if (declaration.fields().isEmpty()) { + code("// no fields to encode") + return@function + } declaration.fields().forEach { (_, field) -> val fieldName = field.name diff --git a/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/model/model.kt b/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/model/model.kt index 57267380..70f093ea 100644 --- a/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/model/model.kt +++ b/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/model/model.kt @@ -67,7 +67,7 @@ data class FieldDeclaration( val packedFixedSize = type.wireType == WireType.FIXED64 || type.wireType == WireType.FIXED32 // aligns with edition settings and backward compatibility with proto2 and proto3 - val nullable: Boolean = dec.hasPresence() && !dec.isRequired && !dec.hasDefaultValue() + val nullable: Boolean = dec.hasPresence() && !dec.isRequired && !dec.hasDefaultValue() && !dec.isRepeated val number: Int = dec.number } From ff36ce83736c5280286855dcd23b2ae6b1e6ca16 Mon Sep 17 00:00:00 2001 From: Johannes Zottele Date: Fri, 1 Aug 2025 11:00:50 +0200 Subject: [PATCH 2/9] grpc-pb: Add presence tracking and required field check Signed-off-by: Johannes Zottele --- .../kotlinx/rpc/grpc/internal/BitSet.kt | 63 ++++ .../kotlin/kotlinx/rpc/grpc/pb/Message.kt | 6 +- .../kotlinx/rpc/grpc/internal/BitSetTest.kt | 304 ++++++++++++++++++ .../kotlin/kotlinx/rpc/grpc/pb/ProtosTest.kt | 38 ++- .../grpc-core/src/commonTest/proto/mini.proto | 8 - .../src/commonTest/proto/presence_check.proto | 9 + .../src/commonTest/proto/repeated.proto | 8 +- .../kotlinx/rpc/protobuf/CodeGenerator.kt | 10 + .../protobuf/ModelToKotlinCommonGenerator.kt | 68 +++- .../rpc/protobuf/codeRequestToModel.kt | 13 +- .../kotlinx/rpc/protobuf/model/model.kt | 6 +- 11 files changed, 503 insertions(+), 30 deletions(-) create mode 100644 grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/internal/BitSet.kt create mode 100644 grpc/grpc-core/src/commonTest/kotlin/kotlinx/rpc/grpc/internal/BitSetTest.kt delete mode 100644 grpc/grpc-core/src/commonTest/proto/mini.proto create mode 100644 grpc/grpc-core/src/commonTest/proto/presence_check.proto diff --git a/grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/internal/BitSet.kt b/grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/internal/BitSet.kt new file mode 100644 index 00000000..16e7bce1 --- /dev/null +++ b/grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/internal/BitSet.kt @@ -0,0 +1,63 @@ +/* + * Copyright 2023-2025 JetBrains s.r.o and contributors. Use of this source code is governed by the Apache 2.0 license. + */ + +package kotlinx.rpc.grpc.internal + +/** + * A fixed-sized vector of bits, allowing one to set/clear/read bits from it by a bit index. + */ +public class BitSet(public val size: Int) { + private val data: LongArray = LongArray((size + 63) ushr 6) + + /** Sets the bit at [index] to 1. */ + public fun set(index: Int) { + require(index >= 0 && index < size) { "Index $index out of bounds for length $size" } + val word = index ushr 6 + data[word] = data[word] or (1L shl (index and 63)) + } + + /** Clears the bit at [index] (sets to 0). */ + public fun clear(index: Int) { + require(index >= 0 && index < size) { "Index $index out of bounds for length $size" } + val word = index ushr 6 + data[word] = data[word] and (1L shl (index and 63)).inv() + } + + /** Returns true if the bit at [index] is set. */ + public operator fun get(index: Int): Boolean { + require(index >= 0 && index < size) { "Index $index out of bounds for length $size" } + val word = index ushr 6 + return (data[word] ushr (index and 63) and 1L) != 0L + } + + /** Clears all bits. */ + public fun clearAll() { + data.fill(0L) + } + + /** Returns the number of bits set to 1. */ + public fun cardinality(): Int { + var sum = 0 + for (w in data) { + sum += w.countOneBits() + } + return sum + } + + /** Returns true if all bits are set. */ + public fun allSet(): Boolean { + val fullWords = size ushr 6 + // check full 64-bit words + for (i in 0 until fullWords) { + if (data[i] != -1L) return false + } + // check leftover bits + val rem = size and 63 + if (rem != 0) { + val mask = (-1L ushr (64 - rem)) + if (data[fullWords] != mask) return false + } + return true + } +} \ No newline at end of file diff --git a/grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/pb/Message.kt b/grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/pb/Message.kt index 5e64a4d0..f9b57a8d 100644 --- a/grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/pb/Message.kt +++ b/grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/pb/Message.kt @@ -4,10 +4,13 @@ package kotlinx.rpc.grpc.pb +import kotlinx.rpc.grpc.internal.BitSet import kotlinx.rpc.internal.utils.InternalRpcApi @InternalRpcApi -public abstract class Message { +public abstract class Message(fieldsWithPresence: Int) { + + public val presenceMask: BitSet = BitSet(fieldsWithPresence) @InternalRpcApi public interface Companion { @@ -17,5 +20,4 @@ public abstract class Message { } public abstract fun encodeWith(encoder: WireEncoder) - } \ No newline at end of file diff --git a/grpc/grpc-core/src/commonTest/kotlin/kotlinx/rpc/grpc/internal/BitSetTest.kt b/grpc/grpc-core/src/commonTest/kotlin/kotlinx/rpc/grpc/internal/BitSetTest.kt new file mode 100644 index 00000000..1ab39751 --- /dev/null +++ b/grpc/grpc-core/src/commonTest/kotlin/kotlinx/rpc/grpc/internal/BitSetTest.kt @@ -0,0 +1,304 @@ +/* + * Copyright 2023-2025 JetBrains s.r.o and contributors. Use of this source code is governed by the Apache 2.0 license. + */ + +package kotlinx.rpc.grpc.internal + +import kotlin.test.* + +class BitSetTest { + + @Test + fun testConstructor() { + // Test with size 0 + val bitSet0 = BitSet(0) + assertEquals(0, bitSet0.size) + assertEquals(0, bitSet0.cardinality()) + + // Test with small size + val bitSet10 = BitSet(10) + assertEquals(10, bitSet10.size) + assertEquals(0, bitSet10.cardinality()) + + // Test with size that spans multiple words + val bitSet100 = BitSet(100) + assertEquals(100, bitSet100.size) + assertEquals(0, bitSet100.cardinality()) + + // Test with size at word boundary + val bitSet64 = BitSet(64) + assertEquals(64, bitSet64.size) + assertEquals(0, bitSet64.cardinality()) + + // Test with size just over word boundary + val bitSet65 = BitSet(65) + assertEquals(65, bitSet65.size) + assertEquals(0, bitSet65.cardinality()) + } + + @Test + fun testSetAndGet() { + val bitSet = BitSet(100) + + // Initially all bits should be unset + for (i in 0 until 100) { + assertFalse(bitSet[i], "Bit $i should be initially unset") + } + + // Set some bits + bitSet.set(0) + bitSet.set(1) + bitSet.set(63) + bitSet.set(64) + bitSet.set(99) + + // Verify the bits are set + assertTrue(bitSet[0], "Bit 0 should be set") + assertTrue(bitSet[1], "Bit 1 should be set") + assertTrue(bitSet[63], "Bit 63 should be set") + assertTrue(bitSet[64], "Bit 64 should be set") + assertTrue(bitSet[99], "Bit 99 should be set") + + // Verify other bits are still unset + assertFalse(bitSet[2], "Bit 2 should be unset") + assertFalse(bitSet[62], "Bit 62 should be unset") + assertFalse(bitSet[65], "Bit 65 should be unset") + assertFalse(bitSet[98], "Bit 98 should be unset") + } + + @Test + fun testClear() { + val bitSet = BitSet(100) + + // Set all bits + for (i in 0 until 100) { + bitSet.set(i) + } + + // Verify all bits are set + for (i in 0 until 100) { + assertTrue(bitSet[i], "Bit $i should be set") + } + + // Clear some bits + bitSet.clear(0) + bitSet.clear(1) + bitSet.clear(63) + bitSet.clear(64) + bitSet.clear(99) + + // Verify the bits are cleared + assertFalse(bitSet[0], "Bit 0 should be cleared") + assertFalse(bitSet[1], "Bit 1 should be cleared") + assertFalse(bitSet[63], "Bit 63 should be cleared") + assertFalse(bitSet[64], "Bit 64 should be cleared") + assertFalse(bitSet[99], "Bit 99 should be cleared") + + // Verify other bits are still set + assertTrue(bitSet[2], "Bit 2 should still be set") + assertTrue(bitSet[62], "Bit 62 should still be set") + assertTrue(bitSet[65], "Bit 65 should still be set") + assertTrue(bitSet[98], "Bit 98 should still be set") + } + + @Test + fun testClearAll() { + val bitSet = BitSet(100) + + // Set all bits + for (i in 0 until 100) { + bitSet.set(i) + } + + // Verify all bits are set + for (i in 0 until 100) { + assertTrue(bitSet[i], "Bit $i should be set") + } + + // Clear all bits + bitSet.clearAll() + + // Verify all bits are cleared + for (i in 0 until 100) { + assertFalse(bitSet[i], "Bit $i should be cleared after clearAll") + } + } + + @Test + fun testCardinality() { + val bitSet = BitSet(100) + assertEquals(0, bitSet.cardinality(), "Initial cardinality should be 0") + + // Set some bits + bitSet.set(0) + assertEquals(1, bitSet.cardinality(), "Cardinality should be 1 after setting 1 bit") + + bitSet.set(63) + assertEquals(2, bitSet.cardinality(), "Cardinality should be 2 after setting 2 bits") + + bitSet.set(64) + assertEquals(3, bitSet.cardinality(), "Cardinality should be 3 after setting 3 bits") + + bitSet.set(99) + assertEquals(4, bitSet.cardinality(), "Cardinality should be 4 after setting 4 bits") + + // Clear a bit + bitSet.clear(0) + assertEquals(3, bitSet.cardinality(), "Cardinality should be 3 after clearing 1 bit") + + // Set a bit that's already set + bitSet.set(63) + assertEquals(3, bitSet.cardinality(), "Cardinality should still be 3 after setting an already set bit") + + // Clear all bits + bitSet.clearAll() + assertEquals(0, bitSet.cardinality(), "Cardinality should be 0 after clearAll") + } + + @Test + fun testAllSet() { + // Test with empty BitSet + val emptyBitSet = BitSet(0) + assertTrue(emptyBitSet.allSet(), "Empty BitSet should return true for allSet") + + // Test with small BitSet + val smallBitSet = BitSet(5) + assertFalse(smallBitSet.allSet(), "New BitSet should return false for allSet") + + smallBitSet.set(0) + smallBitSet.set(1) + smallBitSet.set(2) + smallBitSet.set(3) + smallBitSet.set(4) + assertTrue(smallBitSet.allSet(), "BitSet with all bits set should return true for allSet") + + smallBitSet.clear(2) + assertFalse(smallBitSet.allSet(), "BitSet with one bit cleared should return false for allSet") + + // Test with BitSet that spans multiple words + val largeBitSet = BitSet(100) + assertFalse(largeBitSet.allSet(), "New large BitSet should return false for allSet") + + for (i in 0 until 100) { + largeBitSet.set(i) + } + assertTrue(largeBitSet.allSet(), "Large BitSet with all bits set should return true for allSet") + + largeBitSet.clear(63) + assertFalse(largeBitSet.allSet(), "Large BitSet with one bit cleared should return false for allSet") + + // Test with BitSet at word boundary + val wordBoundaryBitSet = BitSet(64) + assertFalse(wordBoundaryBitSet.allSet(), "New word boundary BitSet should return false for allSet") + + for (i in 0 until 64) { + wordBoundaryBitSet.set(i) + } + assertTrue(wordBoundaryBitSet.allSet(), "Word boundary BitSet with all bits set should return true for allSet") + } + + @Test + fun testEdgeCases() { + val bitSet = BitSet(100) + + // Test setting and getting at boundaries + bitSet.set(0) + assertTrue(bitSet[0], "Should be able to set and get bit 0") + + bitSet.set(99) + assertTrue(bitSet[99], "Should be able to set and get bit at size-1") + + // Test clearing at boundaries + bitSet.clear(0) + assertFalse(bitSet[0], "Should be able to clear bit 0") + + bitSet.clear(99) + assertFalse(bitSet[99], "Should be able to clear bit at size-1") + + // Test out of bounds access + assertFailsWith { + bitSet.set(100) + } + + assertFailsWith { + bitSet.clear(100) + } + + assertFailsWith { + bitSet[100] + } + + assertFailsWith { + bitSet.set(-1) + } + + assertFailsWith { + bitSet.clear(-1) + } + + assertFailsWith { + bitSet[-1] + } + } + + @Test + fun testWordBoundaries() { + // Test BitSet with size at word boundaries + for (size in listOf(63, 64, 65, 127, 128, 129)) { + val bitSet = BitSet(size) + + // Set all bits + for (i in 0 until size) { + bitSet.set(i) + } + + // Verify all bits are set + for (i in 0 until size) { + assertTrue(bitSet[i], "Bit $i should be set in BitSet of size $size") + } + + // Verify cardinality + assertEquals(size, bitSet.cardinality(), "Cardinality should equal size for fully set BitSet") + + // Verify allSet + assertTrue(bitSet.allSet(), "allSet should return true for fully set BitSet") + + // Clear all bits + bitSet.clearAll() + + // Verify all bits are cleared + for (i in 0 until size) { + assertFalse(bitSet[i], "Bit $i should be cleared in BitSet of size $size after clearAll") + } + + // Verify cardinality + assertEquals(0, bitSet.cardinality(), "Cardinality should be 0 after clearAll") + + // Verify allSet + assertFalse(bitSet.allSet(), "allSet should return false after clearAll") + } + } + + @Test + fun testLargeCardinality() { + // Test with a large BitSet to verify cardinality calculation + val size = 1000 + val bitSet = BitSet(size) + + // Set every other bit + for (i in 0 until size step 2) { + bitSet.set(i) + } + + // Verify cardinality + assertEquals(size / 2, bitSet.cardinality(), "Cardinality should be half the size when every other bit is set") + + // Set all bits + for (i in 0 until size) { + bitSet.set(i) + } + + // Verify cardinality + assertEquals(size, bitSet.cardinality(), "Cardinality should equal size when all bits are set") + } +} \ No newline at end of file diff --git a/grpc/grpc-core/src/commonTest/kotlin/kotlinx/rpc/grpc/pb/ProtosTest.kt b/grpc/grpc-core/src/commonTest/kotlin/kotlinx/rpc/grpc/pb/ProtosTest.kt index 32dac983..bcf7d4c3 100644 --- a/grpc/grpc-core/src/commonTest/kotlin/kotlinx/rpc/grpc/pb/ProtosTest.kt +++ b/grpc/grpc-core/src/commonTest/kotlin/kotlinx/rpc/grpc/pb/ProtosTest.kt @@ -8,6 +8,7 @@ import kotlinx.io.Buffer import kotlinx.rpc.grpc.test.common.* import kotlin.test.Test import kotlin.test.assertEquals +import kotlin.test.assertFailsWith class ProtosTest { @@ -49,7 +50,7 @@ class ProtosTest { val msgObj = msg as Message - val decoded = decodeEncode(msgObj, AllPrimitivesCommonBuilder::decodeWith) + val decoded = decodeEncode(msgObj, AllPrimitivesCommonInternal::decodeWith) as AllPrimitivesCommon assertEquals(msg.double, decoded.double) @@ -58,16 +59,43 @@ class ProtosTest { @Test fun testRepeatedProto() { val msg = RepeatedCommon { - listFixed32 = listOf(1, 2, 3).map { it.toUInt() } - listInt32 = listOf(4, 5, 6) + listFixed32 = listOf(1, 5, 3).map { it.toUInt() } + listFixed32Packed = listOf(1, 2, 3).map { it.toUInt() } + listInt32 = listOf(4, 7, 6) + listInt32Packed = listOf(4, 5, 6) listString = listOf("a", "b", "c") } - val decoded = decodeEncode(msg as Message, RepeatedCommonBuilder::decodeWith) as RepeatedCommonBuilder + val decoded = decodeEncode(msg as Message, RepeatedCommonInternal::decodeWith) as RepeatedCommonInternal assertEquals(msg.listInt32, decoded.listInt32) assertEquals(msg.listFixed32, decoded.listFixed32) assertEquals(msg.listString, decoded.listString) } -} + @Test + fun testPresenceCheckProto() { + + // Check a missing required field in a user-constructed message + val presenceCheck = PresenceCheck { + // net no fields + } + assertFailsWith("PresenceCheck is missing required field: RequiredPresence") { + (presenceCheck as Message).encodeWith(WireEncoder(Buffer())) + } + + // Test missing field during decoding of an encoded message + val buffer = Buffer() + val encoder = WireEncoder(buffer) + encoder.writeFloat(2, 1f) + encoder.flush() + + assertFailsWith("PresenceCheck is missing required field: RequiredPresence") { + WireDecoder(buffer).use { + PresenceCheckInternal.decodeWith(it) + } + } + } + + +} \ No newline at end of file diff --git a/grpc/grpc-core/src/commonTest/proto/mini.proto b/grpc/grpc-core/src/commonTest/proto/mini.proto deleted file mode 100644 index 8ef15f6d..00000000 --- a/grpc/grpc-core/src/commonTest/proto/mini.proto +++ /dev/null @@ -1,8 +0,0 @@ -syntax = "proto3"; - -package kotlinx.rpc.grpc.test.common; - -message MiniMsg { - int32 MiniField = 1; -} - diff --git a/grpc/grpc-core/src/commonTest/proto/presence_check.proto b/grpc/grpc-core/src/commonTest/proto/presence_check.proto new file mode 100644 index 00000000..f428ce2c --- /dev/null +++ b/grpc/grpc-core/src/commonTest/proto/presence_check.proto @@ -0,0 +1,9 @@ +syntax = "proto2"; + +package kotlinx.rpc.grpc.test.common; + +message PresenceCheck { + required int32 RequiredPresence = 1; + optional float OptionalPresence = 2; +} + diff --git a/grpc/grpc-core/src/commonTest/proto/repeated.proto b/grpc/grpc-core/src/commonTest/proto/repeated.proto index 93c7ce58..61a858ff 100644 --- a/grpc/grpc-core/src/commonTest/proto/repeated.proto +++ b/grpc/grpc-core/src/commonTest/proto/repeated.proto @@ -3,8 +3,10 @@ syntax = "proto3"; package kotlinx.rpc.grpc.test.common; message RepeatedCommon { - repeated fixed32 listFixed32 = 1 [packed = true]; - repeated int32 listInt32 = 2 [packed = false]; - repeated string listString = 3; + repeated fixed32 listFixed32 = 1 [packed = false]; + repeated fixed32 listFixed32Packed = 2 [packed = true]; + repeated int32 listInt32 = 3 [packed = false]; + repeated int32 listInt32Packed = 4 [packed = true]; + repeated string listString = 5; } diff --git a/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/CodeGenerator.kt b/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/CodeGenerator.kt index b1a6cdef..de3f31e2 100644 --- a/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/CodeGenerator.kt +++ b/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/CodeGenerator.kt @@ -233,6 +233,7 @@ open class CodeGenerator( fun clazz( name: String, modifiers: String = "", + constructorModifiers: String = "", constructorArgs: List> = emptyList(), superTypes: List = emptyList(), annotations: List = emptyList(), @@ -258,8 +259,12 @@ open class CodeGenerator( "$arg$defaultString" } + val constructorModifiersTransformed = if (constructorModifiers.isEmpty()) "" else + " ${constructorModifiers.trim()} constructor " + when { shouldPutArgsOnNewLines && constructorArgsTransformed.isNotEmpty() -> { + append(constructorModifiersTransformed) append("(") newLine() withNextIndent { @@ -271,10 +276,15 @@ open class CodeGenerator( } constructorArgsTransformed.isNotEmpty() -> { + append(constructorModifiersTransformed) append("(") append(constructorArgsTransformed.joinToString(", ")) append(")") } + + constructorModifiersTransformed.isNotEmpty() -> { + append("$constructorModifiersTransformed()") + } } val superString = superTypes diff --git a/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/ModelToKotlinCommonGenerator.kt b/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/ModelToKotlinCommonGenerator.kt index bc88e144..38944c63 100644 --- a/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/ModelToKotlinCommonGenerator.kt +++ b/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/ModelToKotlinCommonGenerator.kt @@ -11,6 +11,7 @@ import kotlinx.rpc.protobuf.model.* import org.slf4j.Logger private const val RPC_INTERNAL_PACKAGE_SUFFIX = "_rpc_internal" +private const val MSG_INTERNAL_SUFFIX = "Internal" class ModelToKotlinCommonGenerator( private val model: Model, @@ -101,6 +102,10 @@ class ModelToKotlinCommonGenerator( fileDeclaration.messageDeclarations.forEach { generateMessageConstructor(it) } + + fileDeclaration.messageDeclarations.forEach { + generateRequiredCheck(it) + } } private fun MessageDeclaration.fields() = actualFields.map { @@ -138,11 +143,11 @@ class ModelToKotlinCommonGenerator( @Suppress("detekt.CyclomaticComplexMethod") private fun CodeGenerator.generateInternalMessage(declaration: MessageDeclaration) { - val builderClassName = "${declaration.name.simpleName}Builder" + val internalClassName = declaration.internalClassName() clazz( - name = builderClassName, + name = internalClassName, declarationType = DeclarationType.Class, - superTypes = listOf(declaration.name.safeFullName(), "Message()"), + superTypes = listOf(declaration.name.safeFullName(), "Message(${declaration.presenceMaskSize})"), ) { declaration.fields().forEach { (fieldDeclaration, field) -> val value = when { @@ -161,6 +166,12 @@ class ModelToKotlinCommonGenerator( } code("override var $fieldDeclaration $value") + if (field.presenceIdx != null) { + scope("set(value) ") { + code("presenceMask.set(${field.presenceIdx})") + code("field = value") + } + } newLine() } @@ -170,7 +181,7 @@ class ModelToKotlinCommonGenerator( generateMessageEncoder(declaration) - scope("companion object: Message.Companion<$builderClassName>") { + scope("companion object: Message.Companion<$internalClassName>") { generateMessageDecoder(declaration) } } @@ -179,20 +190,20 @@ class ModelToKotlinCommonGenerator( private fun CodeGenerator.generateMessageConstructor(declaration: MessageDeclaration) = function( name = "invoke", modifiers = "operator", - args = "body: ${declaration.name.safeFullName("Builder")}.() -> Unit", + args = "body: ${declaration.internalClassFullName()}.() -> Unit", contextReceiver = "${declaration.name.safeFullName()}.Companion", returnType = declaration.name.safeFullName(), ) { - code("return ${declaration.name.safeFullName("Builder")}().apply(body)") + code("return ${declaration.internalClassFullName()}().apply(body)") } private fun CodeGenerator.generateMessageDecoder(declaration: MessageDeclaration) = function( name = "decodeWith", args = "decoder: WireDecoder", modifiers = "override", - returnType = declaration.name.simpleName + "Builder" + returnType = declaration.internalClassName() ) { - code("val msg = ${declaration.name.safeFullName("Builder")}()") + code("val msg = ${declaration.internalClassFullName()}()") whileBlock("!decoder.hadError()") { code("val tag = decoder.readTag() ?: break // EOF, we read the whole message") whenBlock { @@ -205,6 +216,8 @@ class ModelToKotlinCommonGenerator( ifBlock = { code("error(\"Error during decoding of ${declaration.name.simpleName}\")") } ) + code("msg.checkRequiredFields()") + // TODO: Make a lists immutable code("return msg") } @@ -222,7 +235,7 @@ class ModelToKotlinCommonGenerator( code("$assignment decoder.readPacked${fieldType.value.decodeEncodeFuncName()}()") } } else { - whenCase("tag.fieldNr == ${field.number} && tag.wireType == WireType.LENGTH_DELIMITED") { + whenCase("tag.fieldNr == ${field.number} && tag.wireType == WireType.${fieldType.value.wireType.name}") { code("(msg.${field.name} as ArrayList).add(decoder.read${fieldType.value.decodeEncodeFuncName()}())") } } @@ -243,6 +256,9 @@ class ModelToKotlinCommonGenerator( return@function } + // check if the user set all required fields + code("checkRequiredFields()") + declaration.fields().forEach { (_, field) -> val fieldName = field.name if (field.nullable) { @@ -283,6 +299,31 @@ class ModelToKotlinCommonGenerator( } } + + /** + * Generates a function to check for the presence of all required fields in a message declaration. + */ + private fun CodeGenerator.generateRequiredCheck(declaration: MessageDeclaration) = function( + name = "checkRequiredFields", + modifiers = "private", + contextReceiver = declaration.internalClassFullName(), + ) { + val requiredFields = declaration.actualFields + .filter { it.dec.isRequired } + + if (requiredFields.isEmpty()) { + code("// no fields to check") + return@function + } + + requiredFields.forEach { field -> + ifBranch(condition = "!presenceMask[${field.presenceIdx}]", ifBlock = { + code("error(\"${declaration.name.simpleName} is missing required field: ${field.name}\")") + }) + } + } + + private fun FieldDeclaration.wireSizeCall(variable: String): String { val sizeFunc = "WireSize.${type.decodeEncodeFuncName().replaceFirstChar { it.lowercase() }}($variable)" return when (val fieldType = type) { @@ -473,8 +514,17 @@ class ModelToKotlinCommonGenerator( } } } + + private fun MessageDeclaration.internalClassFullName(): String { + return name.safeFullName(MSG_INTERNAL_SUFFIX) + } + + private fun MessageDeclaration.internalClassName(): String { + return name.simpleName + MSG_INTERNAL_SUFFIX + } } + private fun String.packageNameSuffixed(suffix: String): String { return if (isEmpty()) suffix else "$this.$suffix" } diff --git a/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/codeRequestToModel.kt b/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/codeRequestToModel.kt index 3f170270..a357d594 100644 --- a/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/codeRequestToModel.kt +++ b/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/codeRequestToModel.kt @@ -111,10 +111,18 @@ private fun Descriptors.FileDescriptor.toModel(): FileDeclaration = cached { } private fun Descriptors.Descriptor.toModel(): MessageDeclaration = cached { - val regularFields = fields.filter { field -> field.realContainingOneof == null }.map { it.toModel() } + var currPresenceIdx = 0 + val regularFields = fields + // only fields that are not part of a oneOf declaration + .filter { field -> field.realContainingOneof == null } + .map { + val presenceIdx = if (it.hasPresence()) currPresenceIdx++ else null + it.toModel(presenceIdx = presenceIdx) + } return MessageDeclaration( name = fqName(), + presenceMaskSize = currPresenceIdx, actualFields = regularFields, // get all oneof declarations that are not created from an optional in proto3 https://github.com/googleapis/api-linter/issues/1323 oneOfDeclarations = oneofs.filter { it.fields[0].realContainingOneof != null }.map { it.toModel() }, @@ -125,11 +133,12 @@ private fun Descriptors.Descriptor.toModel(): MessageDeclaration = cached { ) } -private fun Descriptors.FieldDescriptor.toModel(): FieldDeclaration = cached { +private fun Descriptors.FieldDescriptor.toModel(presenceIdx: Int? = null): FieldDeclaration = cached { toProto().hasProto3Optional() return FieldDeclaration( name = fqName().simpleName, type = modelType(), + presenceIdx = presenceIdx, doc = null, dec = this, ) diff --git a/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/model/model.kt b/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/model/model.kt index 70f093ea..8eff1407 100644 --- a/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/model/model.kt +++ b/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/model/model.kt @@ -23,6 +23,7 @@ data class FileDeclaration( data class MessageDeclaration( val name: FqName, + val presenceMaskSize: Int, val actualFields: List, // excludes oneOf fields, but includes oneOf itself val oneOfDeclarations: List, val enumDeclarations: List, @@ -62,7 +63,10 @@ data class FieldDeclaration( val name: String, val type: FieldType, val doc: String?, - val dec: Descriptors.FieldDescriptor + val dec: Descriptors.FieldDescriptor, + // defines the index in the presenceMask of the Message. + // this cannot be the number, as only fields with hasPresence == true are part of the presenceMask + val presenceIdx: Int? = null ) { val packedFixedSize = type.wireType == WireType.FIXED64 || type.wireType == WireType.FIXED32 From 545cd6138b5742525868ed7e7dbe4ba243974b5d Mon Sep 17 00:00:00 2001 From: Johannes Zottele Date: Fri, 1 Aug 2025 12:58:12 +0200 Subject: [PATCH 3/9] grpc-pb: Add PresenceIndices object that holds the presence indices of all fields. Signed-off-by: Johannes Zottele --- .../protobuf/ModelToKotlinCommonGenerator.kt | 31 ++++++++++++++++--- 1 file changed, 26 insertions(+), 5 deletions(-) diff --git a/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/ModelToKotlinCommonGenerator.kt b/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/ModelToKotlinCommonGenerator.kt index 38944c63..d6c0040e 100644 --- a/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/ModelToKotlinCommonGenerator.kt +++ b/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/ModelToKotlinCommonGenerator.kt @@ -146,9 +146,16 @@ class ModelToKotlinCommonGenerator( val internalClassName = declaration.internalClassName() clazz( name = internalClassName, + annotations = listOf("@kotlinx.rpc.internal.utils.InternalRpcApi"), declarationType = DeclarationType.Class, - superTypes = listOf(declaration.name.safeFullName(), "Message(${declaration.presenceMaskSize})"), + superTypes = listOf( + declaration.name.safeFullName(), + "Message(fieldsWithPresence = ${declaration.presenceMaskSize})" + ), ) { + + generatePresenceIndicesObject(declaration) + declaration.fields().forEach { (fieldDeclaration, field) -> val value = when { field.nullable -> { @@ -168,7 +175,7 @@ class ModelToKotlinCommonGenerator( code("override var $fieldDeclaration $value") if (field.presenceIdx != null) { scope("set(value) ") { - code("presenceMask.set(${field.presenceIdx})") + code("presenceMask.set(PresenceIndices.${field.name})") code("field = value") } } @@ -187,6 +194,20 @@ class ModelToKotlinCommonGenerator( } } + private fun CodeGenerator.generatePresenceIndicesObject(declaration: MessageDeclaration) { + if (declaration.presenceMaskSize == 0) { + return + } + scope("private object PresenceIndices") { + declaration.fields().forEach { (_, field) -> + if (field.presenceIdx != null) { + code("const val ${field.name} = ${field.presenceIdx}") + newLine() + } + } + } + } + private fun CodeGenerator.generateMessageConstructor(declaration: MessageDeclaration) = function( name = "invoke", modifiers = "operator", @@ -277,13 +298,13 @@ class ModelToKotlinCommonGenerator( private fun FieldDeclaration.writeValue(variable: String): String { return when (val fieldType = type) { - is FieldType.IntegralType -> "encoder.write${type.decodeEncodeFuncName()}($number, $variable)" + is FieldType.IntegralType -> "encoder.write${type.decodeEncodeFuncName()}(fieldNr = $number, value = $variable)" is FieldType.List -> when { dec.isPacked && packedFixedSize -> - "encoder.writePacked${fieldType.value.decodeEncodeFuncName()}($number, $variable)" + "encoder.writePacked${fieldType.value.decodeEncodeFuncName()}(fieldNr = $number, value = $variable)" dec.isPacked && !packedFixedSize -> - "encoder.writePacked${fieldType.value.decodeEncodeFuncName()}($number, $variable, ${ + "encoder.writePacked${fieldType.value.decodeEncodeFuncName()}(fieldNr = $number, value = $variable, size = ${ wireSizeCall( variable ) From 64aa73faa2d7e8849ef75b913cba98a255832831 Mon Sep 17 00:00:00 2001 From: Johannes Zottele Date: Fri, 1 Aug 2025 16:37:28 +0200 Subject: [PATCH 4/9] grpc-pb: Move BitSet to utils Signed-off-by: Johannes Zottele --- .../kotlin/kotlinx/rpc/grpc/pb/Message.kt | 2 +- .../rpc/grpc/{internal => utils}/BitSet.kt | 5 +- .../kotlinx/rpc/grpc/internal/BitSetTest.kt | 101 +++++++++--------- 3 files changed, 56 insertions(+), 52 deletions(-) rename grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/{internal => utils}/BitSet.kt (95%) diff --git a/grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/pb/Message.kt b/grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/pb/Message.kt index f9b57a8d..0196312f 100644 --- a/grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/pb/Message.kt +++ b/grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/pb/Message.kt @@ -4,7 +4,7 @@ package kotlinx.rpc.grpc.pb -import kotlinx.rpc.grpc.internal.BitSet +import kotlinx.rpc.grpc.utils.BitSet import kotlinx.rpc.internal.utils.InternalRpcApi @InternalRpcApi diff --git a/grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/internal/BitSet.kt b/grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/utils/BitSet.kt similarity index 95% rename from grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/internal/BitSet.kt rename to grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/utils/BitSet.kt index 16e7bce1..8989e5db 100644 --- a/grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/internal/BitSet.kt +++ b/grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/utils/BitSet.kt @@ -2,11 +2,14 @@ * Copyright 2023-2025 JetBrains s.r.o and contributors. Use of this source code is governed by the Apache 2.0 license. */ -package kotlinx.rpc.grpc.internal +package kotlinx.rpc.grpc.utils + +import kotlinx.rpc.internal.utils.InternalRpcApi /** * A fixed-sized vector of bits, allowing one to set/clear/read bits from it by a bit index. */ +@InternalRpcApi public class BitSet(public val size: Int) { private val data: LongArray = LongArray((size + 63) ushr 6) diff --git a/grpc/grpc-core/src/commonTest/kotlin/kotlinx/rpc/grpc/internal/BitSetTest.kt b/grpc/grpc-core/src/commonTest/kotlin/kotlinx/rpc/grpc/internal/BitSetTest.kt index 1ab39751..22ac4c5a 100644 --- a/grpc/grpc-core/src/commonTest/kotlin/kotlinx/rpc/grpc/internal/BitSetTest.kt +++ b/grpc/grpc-core/src/commonTest/kotlin/kotlinx/rpc/grpc/internal/BitSetTest.kt @@ -4,6 +4,7 @@ package kotlinx.rpc.grpc.internal +import kotlinx.rpc.grpc.utils.BitSet import kotlin.test.* class BitSetTest { @@ -39,26 +40,26 @@ class BitSetTest { @Test fun testSetAndGet() { val bitSet = BitSet(100) - + // Initially all bits should be unset for (i in 0 until 100) { assertFalse(bitSet[i], "Bit $i should be initially unset") } - + // Set some bits bitSet.set(0) bitSet.set(1) bitSet.set(63) bitSet.set(64) bitSet.set(99) - + // Verify the bits are set assertTrue(bitSet[0], "Bit 0 should be set") assertTrue(bitSet[1], "Bit 1 should be set") assertTrue(bitSet[63], "Bit 63 should be set") assertTrue(bitSet[64], "Bit 64 should be set") assertTrue(bitSet[99], "Bit 99 should be set") - + // Verify other bits are still unset assertFalse(bitSet[2], "Bit 2 should be unset") assertFalse(bitSet[62], "Bit 62 should be unset") @@ -69,31 +70,31 @@ class BitSetTest { @Test fun testClear() { val bitSet = BitSet(100) - + // Set all bits for (i in 0 until 100) { bitSet.set(i) } - + // Verify all bits are set for (i in 0 until 100) { assertTrue(bitSet[i], "Bit $i should be set") } - + // Clear some bits bitSet.clear(0) bitSet.clear(1) bitSet.clear(63) bitSet.clear(64) bitSet.clear(99) - + // Verify the bits are cleared assertFalse(bitSet[0], "Bit 0 should be cleared") assertFalse(bitSet[1], "Bit 1 should be cleared") assertFalse(bitSet[63], "Bit 63 should be cleared") assertFalse(bitSet[64], "Bit 64 should be cleared") assertFalse(bitSet[99], "Bit 99 should be cleared") - + // Verify other bits are still set assertTrue(bitSet[2], "Bit 2 should still be set") assertTrue(bitSet[62], "Bit 62 should still be set") @@ -104,20 +105,20 @@ class BitSetTest { @Test fun testClearAll() { val bitSet = BitSet(100) - + // Set all bits for (i in 0 until 100) { bitSet.set(i) } - + // Verify all bits are set for (i in 0 until 100) { assertTrue(bitSet[i], "Bit $i should be set") } - + // Clear all bits bitSet.clearAll() - + // Verify all bits are cleared for (i in 0 until 100) { assertFalse(bitSet[i], "Bit $i should be cleared after clearAll") @@ -128,28 +129,28 @@ class BitSetTest { fun testCardinality() { val bitSet = BitSet(100) assertEquals(0, bitSet.cardinality(), "Initial cardinality should be 0") - + // Set some bits bitSet.set(0) assertEquals(1, bitSet.cardinality(), "Cardinality should be 1 after setting 1 bit") - + bitSet.set(63) assertEquals(2, bitSet.cardinality(), "Cardinality should be 2 after setting 2 bits") - + bitSet.set(64) assertEquals(3, bitSet.cardinality(), "Cardinality should be 3 after setting 3 bits") - + bitSet.set(99) assertEquals(4, bitSet.cardinality(), "Cardinality should be 4 after setting 4 bits") - + // Clear a bit bitSet.clear(0) assertEquals(3, bitSet.cardinality(), "Cardinality should be 3 after clearing 1 bit") - + // Set a bit that's already set bitSet.set(63) assertEquals(3, bitSet.cardinality(), "Cardinality should still be 3 after setting an already set bit") - + // Clear all bits bitSet.clearAll() assertEquals(0, bitSet.cardinality(), "Cardinality should be 0 after clearAll") @@ -160,37 +161,37 @@ class BitSetTest { // Test with empty BitSet val emptyBitSet = BitSet(0) assertTrue(emptyBitSet.allSet(), "Empty BitSet should return true for allSet") - + // Test with small BitSet val smallBitSet = BitSet(5) assertFalse(smallBitSet.allSet(), "New BitSet should return false for allSet") - + smallBitSet.set(0) smallBitSet.set(1) smallBitSet.set(2) smallBitSet.set(3) smallBitSet.set(4) assertTrue(smallBitSet.allSet(), "BitSet with all bits set should return true for allSet") - + smallBitSet.clear(2) assertFalse(smallBitSet.allSet(), "BitSet with one bit cleared should return false for allSet") - + // Test with BitSet that spans multiple words val largeBitSet = BitSet(100) assertFalse(largeBitSet.allSet(), "New large BitSet should return false for allSet") - + for (i in 0 until 100) { largeBitSet.set(i) } assertTrue(largeBitSet.allSet(), "Large BitSet with all bits set should return true for allSet") - + largeBitSet.clear(63) assertFalse(largeBitSet.allSet(), "Large BitSet with one bit cleared should return false for allSet") - + // Test with BitSet at word boundary val wordBoundaryBitSet = BitSet(64) assertFalse(wordBoundaryBitSet.allSet(), "New word boundary BitSet should return false for allSet") - + for (i in 0 until 64) { wordBoundaryBitSet.set(i) } @@ -200,42 +201,42 @@ class BitSetTest { @Test fun testEdgeCases() { val bitSet = BitSet(100) - + // Test setting and getting at boundaries bitSet.set(0) assertTrue(bitSet[0], "Should be able to set and get bit 0") - + bitSet.set(99) assertTrue(bitSet[99], "Should be able to set and get bit at size-1") - + // Test clearing at boundaries bitSet.clear(0) assertFalse(bitSet[0], "Should be able to clear bit 0") - + bitSet.clear(99) assertFalse(bitSet[99], "Should be able to clear bit at size-1") - + // Test out of bounds access assertFailsWith { bitSet.set(100) } - + assertFailsWith { bitSet.clear(100) } - + assertFailsWith { bitSet[100] } - + assertFailsWith { bitSet.set(-1) } - + assertFailsWith { bitSet.clear(-1) } - + assertFailsWith { bitSet[-1] } @@ -246,34 +247,34 @@ class BitSetTest { // Test BitSet with size at word boundaries for (size in listOf(63, 64, 65, 127, 128, 129)) { val bitSet = BitSet(size) - + // Set all bits for (i in 0 until size) { bitSet.set(i) } - + // Verify all bits are set for (i in 0 until size) { assertTrue(bitSet[i], "Bit $i should be set in BitSet of size $size") } - + // Verify cardinality assertEquals(size, bitSet.cardinality(), "Cardinality should equal size for fully set BitSet") - + // Verify allSet assertTrue(bitSet.allSet(), "allSet should return true for fully set BitSet") - + // Clear all bits bitSet.clearAll() - + // Verify all bits are cleared for (i in 0 until size) { assertFalse(bitSet[i], "Bit $i should be cleared in BitSet of size $size after clearAll") } - + // Verify cardinality assertEquals(0, bitSet.cardinality(), "Cardinality should be 0 after clearAll") - + // Verify allSet assertFalse(bitSet.allSet(), "allSet should return false after clearAll") } @@ -284,20 +285,20 @@ class BitSetTest { // Test with a large BitSet to verify cardinality calculation val size = 1000 val bitSet = BitSet(size) - + // Set every other bit for (i in 0 until size step 2) { bitSet.set(i) } - + // Verify cardinality assertEquals(size / 2, bitSet.cardinality(), "Cardinality should be half the size when every other bit is set") - + // Set all bits for (i in 0 until size) { bitSet.set(i) } - + // Verify cardinality assertEquals(size, bitSet.cardinality(), "Cardinality should equal size when all bits are set") } From 13a3d37909a3b5425fa1ceb16bf50934c5ea018d Mon Sep 17 00:00:00 2001 From: Johannes Zottele Date: Fri, 1 Aug 2025 19:13:47 +0200 Subject: [PATCH 5/9] grpc-pb: Add MessageCodec object for each message Signed-off-by: Johannes Zottele --- .../conventions-kotlin-version.gradle.kts | 1 - .../pb/{Message.kt => InternalMessage.kt} | 11 +--- .../kotlin/kotlinx/rpc/grpc/pb/ProtosTest.kt | 36 +++++-------- .../src/commonTest/proto/repeated.proto | 1 - .../protobuf/ModelToKotlinCommonGenerator.kt | 51 ++++++++++++++----- 5 files changed, 52 insertions(+), 48 deletions(-) rename grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/pb/{Message.kt => InternalMessage.kt} (57%) diff --git a/gradle-conventions/src/main/kotlin/conventions-kotlin-version.gradle.kts b/gradle-conventions/src/main/kotlin/conventions-kotlin-version.gradle.kts index 5994aa8e..77e18b3a 100644 --- a/gradle-conventions/src/main/kotlin/conventions-kotlin-version.gradle.kts +++ b/gradle-conventions/src/main/kotlin/conventions-kotlin-version.gradle.kts @@ -5,7 +5,6 @@ import org.jetbrains.kotlin.gradle.dsl.KotlinCommonCompilerOptions import org.jetbrains.kotlin.gradle.dsl.KotlinProjectExtension import org.jetbrains.kotlin.gradle.dsl.KotlinVersion -import org.jetbrains.kotlin.gradle.plugin.KotlinCompilation import util.withKotlinJvmExtension import util.withKotlinKmpExtension diff --git a/grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/pb/Message.kt b/grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/pb/InternalMessage.kt similarity index 57% rename from grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/pb/Message.kt rename to grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/pb/InternalMessage.kt index 0196312f..d65e74ce 100644 --- a/grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/pb/Message.kt +++ b/grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/pb/InternalMessage.kt @@ -8,16 +8,7 @@ import kotlinx.rpc.grpc.utils.BitSet import kotlinx.rpc.internal.utils.InternalRpcApi @InternalRpcApi -public abstract class Message(fieldsWithPresence: Int) { - +public abstract class InternalMessage(fieldsWithPresence: Int) { public val presenceMask: BitSet = BitSet(fieldsWithPresence) - @InternalRpcApi - public interface Companion { - - public fun decodeWith(decoder: WireDecoder): T - - } - - public abstract fun encodeWith(encoder: WireEncoder) } \ No newline at end of file diff --git a/grpc/grpc-core/src/commonTest/kotlin/kotlinx/rpc/grpc/pb/ProtosTest.kt b/grpc/grpc-core/src/commonTest/kotlin/kotlinx/rpc/grpc/pb/ProtosTest.kt index bcf7d4c3..2cb1791d 100644 --- a/grpc/grpc-core/src/commonTest/kotlin/kotlinx/rpc/grpc/pb/ProtosTest.kt +++ b/grpc/grpc-core/src/commonTest/kotlin/kotlinx/rpc/grpc/pb/ProtosTest.kt @@ -5,6 +5,7 @@ package kotlinx.rpc.grpc.pb import kotlinx.io.Buffer +import kotlinx.rpc.grpc.internal.MessageCodec import kotlinx.rpc.grpc.test.common.* import kotlin.test.Test import kotlin.test.assertEquals @@ -12,19 +13,12 @@ import kotlin.test.assertFailsWith class ProtosTest { - private fun decodeEncode( - msg: T, - decoder: (WireDecoder) -> T, - ): T { - val buffer = Buffer() - val encoder = WireEncoder(buffer) - - msg.encodeWith(encoder) - encoder.flush() - - return WireDecoder(buffer).use { - decoder(it) - } + private fun decodeEncode( + msg: M, + codec: MessageCodec + ): M { + val source = codec.encode(msg) + return codec.decode(source) } @@ -48,10 +42,9 @@ class ProtosTest { bytes = byteArrayOf(1, 2, 3) } - val msgObj = msg as Message + val msgObj = msg - val decoded = decodeEncode(msgObj, AllPrimitivesCommonInternal::decodeWith) - as AllPrimitivesCommon + val decoded = decodeEncode(msgObj, AllPrimitivesCommonInternal.CODEC) assertEquals(msg.double, decoded.double) } @@ -66,7 +59,7 @@ class ProtosTest { listString = listOf("a", "b", "c") } - val decoded = decodeEncode(msg as Message, RepeatedCommonInternal::decodeWith) as RepeatedCommonInternal + val decoded = decodeEncode(msg, RepeatedCommonInternal.CODEC) assertEquals(msg.listInt32, decoded.listInt32) assertEquals(msg.listFixed32, decoded.listFixed32) @@ -77,11 +70,8 @@ class ProtosTest { fun testPresenceCheckProto() { // Check a missing required field in a user-constructed message - val presenceCheck = PresenceCheck { - // net no fields - } assertFailsWith("PresenceCheck is missing required field: RequiredPresence") { - (presenceCheck as Message).encodeWith(WireEncoder(Buffer())) + PresenceCheck {} } // Test missing field during decoding of an encoded message @@ -91,9 +81,7 @@ class ProtosTest { encoder.flush() assertFailsWith("PresenceCheck is missing required field: RequiredPresence") { - WireDecoder(buffer).use { - PresenceCheckInternal.decodeWith(it) - } + PresenceCheckInternal.CODEC.decode(buffer) } } diff --git a/grpc/grpc-core/src/commonTest/proto/repeated.proto b/grpc/grpc-core/src/commonTest/proto/repeated.proto index 61a858ff..9f33c35f 100644 --- a/grpc/grpc-core/src/commonTest/proto/repeated.proto +++ b/grpc/grpc-core/src/commonTest/proto/repeated.proto @@ -9,4 +9,3 @@ message RepeatedCommon { repeated int32 listInt32Packed = 4 [packed = true]; repeated string listString = 5; } - diff --git a/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/ModelToKotlinCommonGenerator.kt b/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/ModelToKotlinCommonGenerator.kt index d6c0040e..72227a33 100644 --- a/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/ModelToKotlinCommonGenerator.kt +++ b/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/ModelToKotlinCommonGenerator.kt @@ -105,7 +105,10 @@ class ModelToKotlinCommonGenerator( fileDeclaration.messageDeclarations.forEach { generateRequiredCheck(it) + generateMessageEncoder(it) + generateMessageDecoder(it) } + } private fun MessageDeclaration.fields() = actualFields.map { @@ -150,7 +153,7 @@ class ModelToKotlinCommonGenerator( declarationType = DeclarationType.Class, superTypes = listOf( declaration.name.safeFullName(), - "Message(fieldsWithPresence = ${declaration.presenceMaskSize})" + "InternalMessage(fieldsWithPresence = ${declaration.presenceMaskSize})" ), ) { @@ -186,10 +189,8 @@ class ModelToKotlinCommonGenerator( generateInternalMessage(nested) } - generateMessageEncoder(declaration) - - scope("companion object: Message.Companion<$internalClassName>") { - generateMessageDecoder(declaration) + scope("companion object") { + generateCodecObject(declaration) } } } @@ -208,6 +209,30 @@ class ModelToKotlinCommonGenerator( } } + private fun CodeGenerator.generateCodecObject(declaration: MessageDeclaration) { + val msgFqName = declaration.name.safeFullName() + val downCastErrorStr = "The message is a custom implementation of ${msgFqName}. This is currently not allowed." + val sourceFqName = "kotlinx.io.Source" + val bufferFqName = "kotlinx.io.Buffer" + scope("val CODEC = object : kotlinx.rpc.grpc.internal.MessageCodec<$msgFqName>") { + function("encode", modifiers = "override", args = "value: $msgFqName", returnType = "$sourceFqName") { + code("val msg = value as? ${declaration.internalClassFullName()} ?: error(\"$downCastErrorStr\")") + code("val buffer = $bufferFqName()") + code("val encoder = WireEncoder(buffer)") + code("msg.encodeWith(encoder)") + code("encoder.flush()") + code("return buffer") + } + + + function("decode", modifiers = "override", args = "stream: $sourceFqName", returnType = msgFqName) { + scope("WireDecoder(stream as $bufferFqName).use") { + code("return ${declaration.internalClassFullName()}.decodeWith(it)") + } + } + } + } + private fun CodeGenerator.generateMessageConstructor(declaration: MessageDeclaration) = function( name = "invoke", modifiers = "operator", @@ -215,13 +240,17 @@ class ModelToKotlinCommonGenerator( contextReceiver = "${declaration.name.safeFullName()}.Companion", returnType = declaration.name.safeFullName(), ) { - code("return ${declaration.internalClassFullName()}().apply(body)") + code("val msg = ${declaration.internalClassFullName()}().apply(body)") + // check if the user set all required fields + code("msg.checkRequiredFields()") + code("return msg") } private fun CodeGenerator.generateMessageDecoder(declaration: MessageDeclaration) = function( name = "decodeWith", + modifiers = "private", args = "decoder: WireDecoder", - modifiers = "override", + contextReceiver = "${declaration.internalClassFullName()}.Companion", returnType = declaration.internalClassName() ) { code("val msg = ${declaration.internalClassFullName()}()") @@ -269,17 +298,15 @@ class ModelToKotlinCommonGenerator( private fun CodeGenerator.generateMessageEncoder(declaration: MessageDeclaration) = function( name = "encodeWith", + modifiers = "private", args = "encoder: WireEncoder", - modifiers = "override" + contextReceiver = declaration.internalClassFullName(), ) { if (declaration.fields().isEmpty()) { code("// no fields to encode") return@function } - // check if the user set all required fields - code("checkRequiredFields()") - declaration.fields().forEach { (_, field) -> val fieldName = field.name if (field.nullable) { @@ -304,7 +331,7 @@ class ModelToKotlinCommonGenerator( "encoder.writePacked${fieldType.value.decodeEncodeFuncName()}(fieldNr = $number, value = $variable)" dec.isPacked && !packedFixedSize -> - "encoder.writePacked${fieldType.value.decodeEncodeFuncName()}(fieldNr = $number, value = $variable, size = ${ + "encoder.writePacked${fieldType.value.decodeEncodeFuncName()}(fieldNr = $number, value = $variable, fieldSize = ${ wireSizeCall( variable ) From adc277ff0b3702f782b21f08cd2242968f40b31a Mon Sep 17 00:00:00 2001 From: Johannes Zottele Date: Fri, 1 Aug 2025 19:27:43 +0200 Subject: [PATCH 6/9] grpc-pb: Use only fully qualified names for kotlinx.rpc.grpc.pb.* classes Signed-off-by: Johannes Zottele --- .../conventions-kotlin-version.gradle.kts | 4 ++-- .../kotlin/kotlinx/rpc/grpc/pb/ProtosTest.kt | 6 ++++++ .../protobuf/ModelToKotlinCommonGenerator.kt | 19 +++++++++---------- 3 files changed, 17 insertions(+), 12 deletions(-) diff --git a/gradle-conventions/src/main/kotlin/conventions-kotlin-version.gradle.kts b/gradle-conventions/src/main/kotlin/conventions-kotlin-version.gradle.kts index 77e18b3a..5c40a62e 100644 --- a/gradle-conventions/src/main/kotlin/conventions-kotlin-version.gradle.kts +++ b/gradle-conventions/src/main/kotlin/conventions-kotlin-version.gradle.kts @@ -24,8 +24,8 @@ fun KotlinProjectExtension.optInForRpcApi() { * This makes our tests execute against the latest compiler plugin version (for example, with K2 instead of K1). */ fun KotlinCommonCompilerOptions.setProjectLanguageVersion() { - languageVersion.set(KotlinVersion.KOTLIN_2_0) - apiVersion.set(KotlinVersion.KOTLIN_2_0) + languageVersion.set(KotlinVersion.KOTLIN_2_1) + apiVersion.set(KotlinVersion.KOTLIN_2_1) } withKotlinJvmExtension { diff --git a/grpc/grpc-core/src/commonTest/kotlin/kotlinx/rpc/grpc/pb/ProtosTest.kt b/grpc/grpc-core/src/commonTest/kotlin/kotlinx/rpc/grpc/pb/ProtosTest.kt index 2cb1791d..c60b3a28 100644 --- a/grpc/grpc-core/src/commonTest/kotlin/kotlinx/rpc/grpc/pb/ProtosTest.kt +++ b/grpc/grpc-core/src/commonTest/kotlin/kotlinx/rpc/grpc/pb/ProtosTest.kt @@ -21,6 +21,12 @@ class ProtosTest { return codec.decode(source) } + fun codecShowCase() { + val msg = AllPrimitivesCommon { + // set fields + } + val decodedMsg = decodeEncode(msg, AllPrimitivesCommonInternal.CODEC) + } @Test fun testAllPrimitiveProto() { diff --git a/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/ModelToKotlinCommonGenerator.kt b/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/ModelToKotlinCommonGenerator.kt index 72227a33..9471a899 100644 --- a/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/ModelToKotlinCommonGenerator.kt +++ b/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/ModelToKotlinCommonGenerator.kt @@ -12,6 +12,7 @@ import org.slf4j.Logger private const val RPC_INTERNAL_PACKAGE_SUFFIX = "_rpc_internal" private const val MSG_INTERNAL_SUFFIX = "Internal" +private const val PB_PKG = "kotlinx.rpc.grpc.pb" class ModelToKotlinCommonGenerator( private val model: Model, @@ -78,8 +79,6 @@ class ModelToKotlinCommonGenerator( import("kotlinx.rpc.internal.utils.*") import("kotlinx.coroutines.flow.*") - import("kotlinx.rpc.grpc.pb.*") - additionalInternalImports.forEach { import(it) @@ -153,7 +152,7 @@ class ModelToKotlinCommonGenerator( declarationType = DeclarationType.Class, superTypes = listOf( declaration.name.safeFullName(), - "InternalMessage(fieldsWithPresence = ${declaration.presenceMaskSize})" + "$PB_PKG.InternalMessage(fieldsWithPresence = ${declaration.presenceMaskSize})" ), ) { @@ -218,7 +217,7 @@ class ModelToKotlinCommonGenerator( function("encode", modifiers = "override", args = "value: $msgFqName", returnType = "$sourceFqName") { code("val msg = value as? ${declaration.internalClassFullName()} ?: error(\"$downCastErrorStr\")") code("val buffer = $bufferFqName()") - code("val encoder = WireEncoder(buffer)") + code("val encoder = $PB_PKG.WireEncoder(buffer)") code("msg.encodeWith(encoder)") code("encoder.flush()") code("return buffer") @@ -226,7 +225,7 @@ class ModelToKotlinCommonGenerator( function("decode", modifiers = "override", args = "stream: $sourceFqName", returnType = msgFqName) { - scope("WireDecoder(stream as $bufferFqName).use") { + scope("$PB_PKG.WireDecoder(stream as $bufferFqName).use") { code("return ${declaration.internalClassFullName()}.decodeWith(it)") } } @@ -249,7 +248,7 @@ class ModelToKotlinCommonGenerator( private fun CodeGenerator.generateMessageDecoder(declaration: MessageDeclaration) = function( name = "decodeWith", modifiers = "private", - args = "decoder: WireDecoder", + args = "decoder: $PB_PKG.WireDecoder", contextReceiver = "${declaration.internalClassFullName()}.Companion", returnType = declaration.internalClassName() ) { @@ -276,16 +275,16 @@ class ModelToKotlinCommonGenerator( val encFuncName = field.type.decodeEncodeFuncName() val assignment = "msg.${field.name} =" when (val fieldType = field.type) { - is FieldType.IntegralType -> whenCase("tag.fieldNr == ${field.number} && tag.wireType == WireType.${field.type.wireType.name}") { + is FieldType.IntegralType -> whenCase("tag.fieldNr == ${field.number} && tag.wireType == $PB_PKG.WireType.${field.type.wireType.name}") { code("$assignment decoder.read$encFuncName()") } is FieldType.List -> if (field.dec.isPacked) { - whenCase("tag.fieldNr == ${field.number} && tag.wireType == WireType.LENGTH_DELIMITED") { + whenCase("tag.fieldNr == ${field.number} && tag.wireType == $PB_PKG.WireType.LENGTH_DELIMITED") { code("$assignment decoder.readPacked${fieldType.value.decodeEncodeFuncName()}()") } } else { - whenCase("tag.fieldNr == ${field.number} && tag.wireType == WireType.${fieldType.value.wireType.name}") { + whenCase("tag.fieldNr == ${field.number} && tag.wireType == $PB_PKG.WireType.${fieldType.value.wireType.name}") { code("(msg.${field.name} as ArrayList).add(decoder.read${fieldType.value.decodeEncodeFuncName()}())") } } @@ -299,7 +298,7 @@ class ModelToKotlinCommonGenerator( private fun CodeGenerator.generateMessageEncoder(declaration: MessageDeclaration) = function( name = "encodeWith", modifiers = "private", - args = "encoder: WireEncoder", + args = "encoder: $PB_PKG.WireEncoder", contextReceiver = declaration.internalClassFullName(), ) { if (declaration.fields().isEmpty()) { From f380a05500e52ce1652d73be262ead0cde320d3c Mon Sep 17 00:00:00 2001 From: Johannes Zottele Date: Fri, 1 Aug 2025 19:47:12 +0200 Subject: [PATCH 7/9] Revert kotlin version increase Signed-off-by: Johannes Zottele --- .../src/main/kotlin/conventions-kotlin-version.gradle.kts | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/gradle-conventions/src/main/kotlin/conventions-kotlin-version.gradle.kts b/gradle-conventions/src/main/kotlin/conventions-kotlin-version.gradle.kts index 5c40a62e..77e18b3a 100644 --- a/gradle-conventions/src/main/kotlin/conventions-kotlin-version.gradle.kts +++ b/gradle-conventions/src/main/kotlin/conventions-kotlin-version.gradle.kts @@ -24,8 +24,8 @@ fun KotlinProjectExtension.optInForRpcApi() { * This makes our tests execute against the latest compiler plugin version (for example, with K2 instead of K1). */ fun KotlinCommonCompilerOptions.setProjectLanguageVersion() { - languageVersion.set(KotlinVersion.KOTLIN_2_1) - apiVersion.set(KotlinVersion.KOTLIN_2_1) + languageVersion.set(KotlinVersion.KOTLIN_2_0) + apiVersion.set(KotlinVersion.KOTLIN_2_0) } withKotlinJvmExtension { From 524f820c576cc7b90b9fc11d82030ad024523df6 Mon Sep 17 00:00:00 2001 From: Johannes Zottele Date: Fri, 1 Aug 2025 19:55:38 +0200 Subject: [PATCH 8/9] grpc-pb: Remove demo test Signed-off-by: Johannes Zottele --- .../commonTest/kotlin/kotlinx/rpc/grpc/pb/ProtosTest.kt | 7 ------- 1 file changed, 7 deletions(-) diff --git a/grpc/grpc-core/src/commonTest/kotlin/kotlinx/rpc/grpc/pb/ProtosTest.kt b/grpc/grpc-core/src/commonTest/kotlin/kotlinx/rpc/grpc/pb/ProtosTest.kt index c60b3a28..ee4f39d2 100644 --- a/grpc/grpc-core/src/commonTest/kotlin/kotlinx/rpc/grpc/pb/ProtosTest.kt +++ b/grpc/grpc-core/src/commonTest/kotlin/kotlinx/rpc/grpc/pb/ProtosTest.kt @@ -21,13 +21,6 @@ class ProtosTest { return codec.decode(source) } - fun codecShowCase() { - val msg = AllPrimitivesCommon { - // set fields - } - val decodedMsg = decodeEncode(msg, AllPrimitivesCommonInternal.CODEC) - } - @Test fun testAllPrimitiveProto() { val msg = AllPrimitivesCommon { From a8a10fef0e0b7eb784b52755d74c8f855f1fa5b5 Mon Sep 17 00:00:00 2001 From: Johannes Zottele Date: Mon, 4 Aug 2025 16:49:57 +0200 Subject: [PATCH 9/9] grpc-pb: Address PR comments Signed-off-by: Johannes Zottele --- .../kotlinx/rpc/grpc/pb/InternalMessage.kt | 1 - .../kotlin/kotlinx/rpc/grpc/utils/BitSet.kt | 8 ++- .../kotlinx/rpc/grpc/internal/BitSetTest.kt | 62 +++++++++---------- .../kotlin/kotlinx/rpc/grpc/pb/ProtosTest.kt | 8 +-- .../src/commonTest/proto/all_primitives.proto | 4 +- .../src/commonTest/proto/repeated.proto | 4 +- .../protobuf/ModelToKotlinCommonGenerator.kt | 20 +++--- 7 files changed, 53 insertions(+), 54 deletions(-) diff --git a/grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/pb/InternalMessage.kt b/grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/pb/InternalMessage.kt index d65e74ce..f9d07f01 100644 --- a/grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/pb/InternalMessage.kt +++ b/grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/pb/InternalMessage.kt @@ -10,5 +10,4 @@ import kotlinx.rpc.internal.utils.InternalRpcApi @InternalRpcApi public abstract class InternalMessage(fieldsWithPresence: Int) { public val presenceMask: BitSet = BitSet(fieldsWithPresence) - } \ No newline at end of file diff --git a/grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/utils/BitSet.kt b/grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/utils/BitSet.kt index 8989e5db..85520ca0 100644 --- a/grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/utils/BitSet.kt +++ b/grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/utils/BitSet.kt @@ -14,10 +14,12 @@ public class BitSet(public val size: Int) { private val data: LongArray = LongArray((size + 63) ushr 6) /** Sets the bit at [index] to 1. */ - public fun set(index: Int) { - require(index >= 0 && index < size) { "Index $index out of bounds for length $size" } + public operator fun set(index: Int, value: Boolean) { + if (!value) return clear(index) + require(index in 0 until size) { "Index $index out‑of‑bounds for length $size" } val word = index ushr 6 - data[word] = data[word] or (1L shl (index and 63)) + val mask = 1L shl (index and 63) + data[word] = data[word] or mask } /** Clears the bit at [index] (sets to 0). */ diff --git a/grpc/grpc-core/src/commonTest/kotlin/kotlinx/rpc/grpc/internal/BitSetTest.kt b/grpc/grpc-core/src/commonTest/kotlin/kotlinx/rpc/grpc/internal/BitSetTest.kt index 22ac4c5a..0644f91c 100644 --- a/grpc/grpc-core/src/commonTest/kotlin/kotlinx/rpc/grpc/internal/BitSetTest.kt +++ b/grpc/grpc-core/src/commonTest/kotlin/kotlinx/rpc/grpc/internal/BitSetTest.kt @@ -47,11 +47,11 @@ class BitSetTest { } // Set some bits - bitSet.set(0) - bitSet.set(1) - bitSet.set(63) - bitSet.set(64) - bitSet.set(99) + bitSet[0] = true + bitSet[1] = true + bitSet[63] = true + bitSet[64] = true + bitSet[99] = true // Verify the bits are set assertTrue(bitSet[0], "Bit 0 should be set") @@ -73,7 +73,7 @@ class BitSetTest { // Set all bits for (i in 0 until 100) { - bitSet.set(i) + bitSet[i] = true } // Verify all bits are set @@ -82,11 +82,11 @@ class BitSetTest { } // Clear some bits - bitSet.clear(0) - bitSet.clear(1) - bitSet.clear(63) - bitSet.clear(64) - bitSet.clear(99) + bitSet[0] = false + bitSet[1] = false + bitSet[63] = false + bitSet[64] = false + bitSet[99] = false // Verify the bits are cleared assertFalse(bitSet[0], "Bit 0 should be cleared") @@ -108,7 +108,7 @@ class BitSetTest { // Set all bits for (i in 0 until 100) { - bitSet.set(i) + bitSet[i] = true } // Verify all bits are set @@ -131,16 +131,16 @@ class BitSetTest { assertEquals(0, bitSet.cardinality(), "Initial cardinality should be 0") // Set some bits - bitSet.set(0) + bitSet[0] = true assertEquals(1, bitSet.cardinality(), "Cardinality should be 1 after setting 1 bit") - bitSet.set(63) + bitSet[63] = true assertEquals(2, bitSet.cardinality(), "Cardinality should be 2 after setting 2 bits") - bitSet.set(64) + bitSet[64] = true assertEquals(3, bitSet.cardinality(), "Cardinality should be 3 after setting 3 bits") - bitSet.set(99) + bitSet[99] = true assertEquals(4, bitSet.cardinality(), "Cardinality should be 4 after setting 4 bits") // Clear a bit @@ -148,7 +148,7 @@ class BitSetTest { assertEquals(3, bitSet.cardinality(), "Cardinality should be 3 after clearing 1 bit") // Set a bit that's already set - bitSet.set(63) + bitSet[63] = true assertEquals(3, bitSet.cardinality(), "Cardinality should still be 3 after setting an already set bit") // Clear all bits @@ -166,11 +166,11 @@ class BitSetTest { val smallBitSet = BitSet(5) assertFalse(smallBitSet.allSet(), "New BitSet should return false for allSet") - smallBitSet.set(0) - smallBitSet.set(1) - smallBitSet.set(2) - smallBitSet.set(3) - smallBitSet.set(4) + smallBitSet[0] = true + smallBitSet[1] = true + smallBitSet[2] = true + smallBitSet[3] = true + smallBitSet[4] = true assertTrue(smallBitSet.allSet(), "BitSet with all bits set should return true for allSet") smallBitSet.clear(2) @@ -181,7 +181,7 @@ class BitSetTest { assertFalse(largeBitSet.allSet(), "New large BitSet should return false for allSet") for (i in 0 until 100) { - largeBitSet.set(i) + largeBitSet[i] = true } assertTrue(largeBitSet.allSet(), "Large BitSet with all bits set should return true for allSet") @@ -193,7 +193,7 @@ class BitSetTest { assertFalse(wordBoundaryBitSet.allSet(), "New word boundary BitSet should return false for allSet") for (i in 0 until 64) { - wordBoundaryBitSet.set(i) + wordBoundaryBitSet[i] = true } assertTrue(wordBoundaryBitSet.allSet(), "Word boundary BitSet with all bits set should return true for allSet") } @@ -203,10 +203,10 @@ class BitSetTest { val bitSet = BitSet(100) // Test setting and getting at boundaries - bitSet.set(0) + bitSet[0] = true assertTrue(bitSet[0], "Should be able to set and get bit 0") - bitSet.set(99) + bitSet[99] = true assertTrue(bitSet[99], "Should be able to set and get bit at size-1") // Test clearing at boundaries @@ -218,7 +218,7 @@ class BitSetTest { // Test out of bounds access assertFailsWith { - bitSet.set(100) + bitSet[100] = true } assertFailsWith { @@ -230,7 +230,7 @@ class BitSetTest { } assertFailsWith { - bitSet.set(-1) + bitSet[-1] = true } assertFailsWith { @@ -250,7 +250,7 @@ class BitSetTest { // Set all bits for (i in 0 until size) { - bitSet.set(i) + bitSet[i] = true } // Verify all bits are set @@ -288,7 +288,7 @@ class BitSetTest { // Set every other bit for (i in 0 until size step 2) { - bitSet.set(i) + bitSet[i] = true } // Verify cardinality @@ -296,7 +296,7 @@ class BitSetTest { // Set all bits for (i in 0 until size) { - bitSet.set(i) + bitSet[i] = true } // Verify cardinality diff --git a/grpc/grpc-core/src/commonTest/kotlin/kotlinx/rpc/grpc/pb/ProtosTest.kt b/grpc/grpc-core/src/commonTest/kotlin/kotlinx/rpc/grpc/pb/ProtosTest.kt index ee4f39d2..86426cff 100644 --- a/grpc/grpc-core/src/commonTest/kotlin/kotlinx/rpc/grpc/pb/ProtosTest.kt +++ b/grpc/grpc-core/src/commonTest/kotlin/kotlinx/rpc/grpc/pb/ProtosTest.kt @@ -23,7 +23,7 @@ class ProtosTest { @Test fun testAllPrimitiveProto() { - val msg = AllPrimitivesCommon { + val msg = AllPrimitives { int32 = 12 int64 = 1234567890123456789L uint32 = 12345u @@ -43,14 +43,14 @@ class ProtosTest { val msgObj = msg - val decoded = decodeEncode(msgObj, AllPrimitivesCommonInternal.CODEC) + val decoded = decodeEncode(msgObj, AllPrimitivesInternal.CODEC) assertEquals(msg.double, decoded.double) } @Test fun testRepeatedProto() { - val msg = RepeatedCommon { + val msg = Repeated { listFixed32 = listOf(1, 5, 3).map { it.toUInt() } listFixed32Packed = listOf(1, 2, 3).map { it.toUInt() } listInt32 = listOf(4, 7, 6) @@ -58,7 +58,7 @@ class ProtosTest { listString = listOf("a", "b", "c") } - val decoded = decodeEncode(msg, RepeatedCommonInternal.CODEC) + val decoded = decodeEncode(msg, RepeatedInternal.CODEC) assertEquals(msg.listInt32, decoded.listInt32) assertEquals(msg.listFixed32, decoded.listFixed32) diff --git a/grpc/grpc-core/src/commonTest/proto/all_primitives.proto b/grpc/grpc-core/src/commonTest/proto/all_primitives.proto index 14772b74..99b12cd5 100644 --- a/grpc/grpc-core/src/commonTest/proto/all_primitives.proto +++ b/grpc/grpc-core/src/commonTest/proto/all_primitives.proto @@ -2,7 +2,7 @@ syntax = "proto3"; package kotlinx.rpc.grpc.test.common; -message AllPrimitivesCommon { +message AllPrimitives { double double = 1; float float = 2; int32 int32 = 3; @@ -18,4 +18,4 @@ message AllPrimitivesCommon { optional bool bool = 13; optional string string = 14; optional bytes bytes = 15; -} +} \ No newline at end of file diff --git a/grpc/grpc-core/src/commonTest/proto/repeated.proto b/grpc/grpc-core/src/commonTest/proto/repeated.proto index 9f33c35f..e80c7b44 100644 --- a/grpc/grpc-core/src/commonTest/proto/repeated.proto +++ b/grpc/grpc-core/src/commonTest/proto/repeated.proto @@ -2,10 +2,10 @@ syntax = "proto3"; package kotlinx.rpc.grpc.test.common; -message RepeatedCommon { +message Repeated { repeated fixed32 listFixed32 = 1 [packed = false]; repeated fixed32 listFixed32Packed = 2 [packed = true]; repeated int32 listInt32 = 3 [packed = false]; repeated int32 listInt32Packed = 4 [packed = true]; repeated string listString = 5; -} +} \ No newline at end of file diff --git a/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/ModelToKotlinCommonGenerator.kt b/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/ModelToKotlinCommonGenerator.kt index 9471a899..8c79cca7 100644 --- a/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/ModelToKotlinCommonGenerator.kt +++ b/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/ModelToKotlinCommonGenerator.kt @@ -13,6 +13,7 @@ import org.slf4j.Logger private const val RPC_INTERNAL_PACKAGE_SUFFIX = "_rpc_internal" private const val MSG_INTERNAL_SUFFIX = "Internal" private const val PB_PKG = "kotlinx.rpc.grpc.pb" +private const val INTERNAL_RPC_API_ANNO = "kotlinx.rpc.internal.utils.InternalRpcApi" class ModelToKotlinCommonGenerator( private val model: Model, @@ -69,7 +70,7 @@ class ModelToKotlinCommonGenerator( this@generateInternalKotlinFile.packageName.safeFullName() .packageNameSuffixed(RPC_INTERNAL_PACKAGE_SUFFIX) - fileOptIns = listOf("ExperimentalRpcApi::class", "InternalRpcApi::class") + fileOptIns = listOf("ExperimentalRpcApi::class", "$INTERNAL_RPC_API_ANNO::class") dependencies.forEach { dependency -> importPackage(dependency.packageName.safeFullName()) @@ -77,9 +78,6 @@ class ModelToKotlinCommonGenerator( generateInternalDeclaredEntities(this@generateInternalKotlinFile) - import("kotlinx.rpc.internal.utils.*") - import("kotlinx.coroutines.flow.*") - additionalInternalImports.forEach { import(it) } @@ -148,7 +146,7 @@ class ModelToKotlinCommonGenerator( val internalClassName = declaration.internalClassName() clazz( name = internalClassName, - annotations = listOf("@kotlinx.rpc.internal.utils.InternalRpcApi"), + annotations = listOf("@$INTERNAL_RPC_API_ANNO"), declarationType = DeclarationType.Class, superTypes = listOf( declaration.name.safeFullName(), @@ -177,7 +175,7 @@ class ModelToKotlinCommonGenerator( code("override var $fieldDeclaration $value") if (field.presenceIdx != null) { scope("set(value) ") { - code("presenceMask.set(PresenceIndices.${field.name})") + code("presenceMask[PresenceIndices.${field.name}] = true") code("field = value") } } @@ -210,12 +208,13 @@ class ModelToKotlinCommonGenerator( private fun CodeGenerator.generateCodecObject(declaration: MessageDeclaration) { val msgFqName = declaration.name.safeFullName() - val downCastErrorStr = "The message is a custom implementation of ${msgFqName}. This is currently not allowed." + val downCastErrorStr = + "\${value::class.simpleName} implements ${msgFqName}, which is prohibited." val sourceFqName = "kotlinx.io.Source" val bufferFqName = "kotlinx.io.Buffer" scope("val CODEC = object : kotlinx.rpc.grpc.internal.MessageCodec<$msgFqName>") { - function("encode", modifiers = "override", args = "value: $msgFqName", returnType = "$sourceFqName") { - code("val msg = value as? ${declaration.internalClassFullName()} ?: error(\"$downCastErrorStr\")") + function("encode", modifiers = "override", args = "value: $msgFqName", returnType = sourceFqName) { + code("val msg = value as? ${declaration.internalClassFullName()} ?: error { \"$downCastErrorStr\" }") code("val buffer = $bufferFqName()") code("val encoder = $PB_PKG.WireEncoder(buffer)") code("msg.encodeWith(encoder)") @@ -223,7 +222,6 @@ class ModelToKotlinCommonGenerator( code("return buffer") } - function("decode", modifiers = "override", args = "stream: $sourceFqName", returnType = msgFqName) { scope("$PB_PKG.WireDecoder(stream as $bufferFqName).use") { code("return ${declaration.internalClassFullName()}.decodeWith(it)") @@ -372,7 +370,7 @@ class ModelToKotlinCommonGenerator( private fun FieldDeclaration.wireSizeCall(variable: String): String { - val sizeFunc = "WireSize.${type.decodeEncodeFuncName().replaceFirstChar { it.lowercase() }}($variable)" + val sizeFunc = "$PB_PKG.WireSize.${type.decodeEncodeFuncName().replaceFirstChar { it.lowercase() }}($variable)" return when (val fieldType = type) { is FieldType.IntegralType -> when { fieldType.wireType == WireType.FIXED32 -> "32"