Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions wire-grpc-client/api/wire-grpc-client.api
Original file line number Diff line number Diff line change
Expand Up @@ -81,9 +81,14 @@ public final class com/squareup/wire/GrpcHttpUrlKt {

public final class com/squareup/wire/GrpcMethod {
public fun <init> (Ljava/lang/String;Lcom/squareup/wire/ProtoAdapter;Lcom/squareup/wire/ProtoAdapter;)V
public fun <init> (Ljava/lang/String;Lcom/squareup/wire/ProtoAdapter;Lcom/squareup/wire/ProtoAdapter;Z)V
public fun <init> (Ljava/lang/String;Lcom/squareup/wire/ProtoAdapter;Lcom/squareup/wire/ProtoAdapter;ZZ)V
public synthetic fun <init> (Ljava/lang/String;Lcom/squareup/wire/ProtoAdapter;Lcom/squareup/wire/ProtoAdapter;ZZILkotlin/jvm/internal/DefaultConstructorMarker;)V
public final fun getPath ()Ljava/lang/String;
public final fun getRequestAdapter ()Lcom/squareup/wire/ProtoAdapter;
public final fun getRequestStreaming ()Z
public final fun getResponseAdapter ()Lcom/squareup/wire/ProtoAdapter;
public final fun getResponseStreaming ()Z
}

public abstract interface class com/squareup/wire/GrpcServerStreamingCall {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,12 @@
*/
package com.squareup.wire

class GrpcMethod<S : Any, R : Any>(
import kotlin.jvm.JvmOverloads

class GrpcMethod<S : Any, R : Any> @JvmOverloads constructor(
val path: String,
val requestAdapter: ProtoAdapter<S>,
val responseAdapter: ProtoAdapter<R>,
val requestStreaming: Boolean = false,
val responseStreaming: Boolean = false,
)
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,10 @@
package com.squareup.wire

import com.squareup.wire.internal.RealGrpcCall
import com.squareup.wire.internal.RealGrpcServerStreamingCall
import com.squareup.wire.internal.RealGrpcStreamingCall
import com.squareup.wire.internal.asGrpcClientStreamingCall
import com.squareup.wire.internal.asGrpcServerStreamingCall
import com.squareup.wire.internal.asGrpcStreamingCall
import java.util.concurrent.TimeUnit
import kotlin.reflect.KClass
import okhttp3.Call
Expand Down Expand Up @@ -183,9 +184,15 @@ internal class WireGrpcClient internal constructor(
) : GrpcClient() {
override fun <S : Any, R : Any> newCall(method: GrpcMethod<S, R>): GrpcCall<S, R> = RealGrpcCall(this, method)

override fun <S : Any, R : Any> newStreamingCall(method: GrpcMethod<S, R>): GrpcStreamingCall<S, R> = RealGrpcStreamingCall(this, method)
override fun <S : Any, R : Any> newStreamingCall(method: GrpcMethod<S, R>): GrpcStreamingCall<S, R> {
return if (!method.requestStreaming && method.responseStreaming) {
RealGrpcServerStreamingCall(this, method).asGrpcStreamingCall()
} else {
RealGrpcStreamingCall(this, method)
}
}

override fun <S : Any, R : Any> newClientStreamingCall(method: GrpcMethod<S, R>): GrpcClientStreamingCall<S, R> = RealGrpcStreamingCall(this, method).asGrpcClientStreamingCall()

override fun <S : Any, R : Any> newServerStreamingCall(method: GrpcMethod<S, R>): GrpcServerStreamingCall<S, R> = RealGrpcStreamingCall(this, method).asGrpcServerStreamingCall()
override fun <S : Any, R : Any> newServerStreamingCall(method: GrpcMethod<S, R>): GrpcServerStreamingCall<S, R> = RealGrpcServerStreamingCall(this, method)
}
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ import okio.IOException
* * Complete: enqueued when the stream completes normally.
*/
internal class BlockingMessageSource<R : Any>(
val grpcCall: RealGrpcStreamingCall<*, R>,
val onResponseMetadata: (Map<String, String>) -> Unit,
val responseAdapter: ProtoAdapter<R>,
val call: Call,
) : MessageSource<R> {
Expand Down Expand Up @@ -66,7 +66,7 @@ internal class BlockingMessageSource<R : Any>(

override fun onResponse(call: Call, response: Response) {
try {
grpcCall.responseMetadata = response.headers.toMap()
onResponseMetadata(response.headers.toMap())
response.use {
response.messageSource(responseAdapter).use { reader ->
while (true) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,122 @@ package com.squareup.wire.internal
import com.squareup.wire.GrpcMethod
import com.squareup.wire.GrpcServerStreamingCall
import com.squareup.wire.GrpcStreamingCall
import com.squareup.wire.MessageSink
import com.squareup.wire.MessageSource
import com.squareup.wire.WireGrpcClient
import java.util.concurrent.TimeUnit
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.GlobalScope
import kotlinx.coroutines.channels.Channel
import kotlinx.coroutines.channels.ReceiveChannel
import kotlinx.coroutines.channels.SendChannel
import kotlinx.coroutines.launch
import kotlinx.coroutines.runBlocking
import okio.ForwardingTimeout
import okio.IOException
import okio.Timeout

/**
* A [GrpcServerStreamingCall] that sends a single non-duplex request and reads a streaming
* response. Using a non-duplex request body ensures the complete request (including END_STREAM) is
* sent to the server before responses are read, avoiding delays on servers that wait for the
* client's half-close before starting to stream responses.
*/
internal class RealGrpcServerStreamingCall<S : Any, R : Any>(
private val grpcClient: WireGrpcClient,
override val method: GrpcMethod<S, R>,
) : GrpcServerStreamingCall<S, R> {

private var call: okhttp3.Call? = null
private var canceled = false

override val timeout: Timeout = ForwardingTimeout(Timeout())

init {
timeout.clearTimeout()
timeout.clearDeadline()
}

override var requestMetadata: Map<String, String> = mapOf()

override var responseMetadata: Map<String, String>? = null
internal set

override fun cancel() {
canceled = true
call?.cancel()
}

override fun isCanceled(): Boolean = canceled || call?.isCanceled() == true

override fun isExecuted(): Boolean = call?.isExecuted() ?: false

override fun clone(): GrpcServerStreamingCall<S, R> {
val result = RealGrpcServerStreamingCall(grpcClient, method)
val oldTimeout = this.timeout
result.timeout.also { newTimeout ->
newTimeout.timeout(oldTimeout.timeoutNanos(), TimeUnit.NANOSECONDS)
if (oldTimeout.hasDeadline()) {
newTimeout.deadlineNanoTime(oldTimeout.deadlineNanoTime())
} else {
newTimeout.clearDeadline()
}
}
result.requestMetadata += this.requestMetadata
return result
}

override suspend fun executeIn(scope: CoroutineScope, request: S): ReceiveChannel<R> {
val responseChannel = Channel<R>(1)
val call = initCall(request)

responseChannel.invokeOnClose { cause ->
if (cause != null) {
call.cancel()
}
}

call.enqueue(
responseChannel.readFromResponseBodyCallback(
onResponseMetadata = { this.responseMetadata = it },
responseAdapter = method.responseAdapter,
),
)

return responseChannel
}

override fun executeBlocking(request: S): MessageSource<R> {
val call = initCall(request)
val messageSource = BlockingMessageSource(
onResponseMetadata = { this.responseMetadata = it },
responseAdapter = method.responseAdapter,
call = call,
)
call.enqueue(messageSource.readFromResponseBodyCallback())
return messageSource
}

private fun initCall(request: S): okhttp3.Call {
check(this.call == null) { "already executed" }
val requestBody = newRequestBody(
minMessageToCompress = grpcClient.minMessageToCompress,
requestAdapter = method.requestAdapter,
onlyMessage = request,
)
val result = grpcClient.newCall(method, requestMetadata, requestBody, timeout)
this.call = result
if (canceled) result.cancel()
(timeout as ForwardingTimeout).setDelegate(result.timeout())
return result
}
}

/**
* Wraps a [GrpcStreamingCall] as a [GrpcServerStreamingCall]. Used for test doubles created via
* [com.squareup.wire.GrpcServerStreamingCall] factory functions in GrpcCalls.
*/
internal class GrpcStreamingCallServerStreamingAdapter<S : Any, R : Any>(
private val callDelegate: GrpcStreamingCall<S, R>,
override val method: GrpcMethod<S, R>,
) : GrpcServerStreamingCall<S, R> {
Expand All @@ -48,7 +158,7 @@ internal class RealGrpcServerStreamingCall<S : Any, R : Any>(

override fun isExecuted() = callDelegate.isExecuted()

override fun clone() = RealGrpcServerStreamingCall(callDelegate.clone(), method)
override fun clone() = GrpcStreamingCallServerStreamingAdapter(callDelegate.clone(), method)

override suspend fun executeIn(scope: CoroutineScope, request: S): ReceiveChannel<R> {
val (sendChannel, receiveChannel) = callDelegate.executeIn(scope)
Expand All @@ -67,4 +177,129 @@ internal class RealGrpcServerStreamingCall<S : Any, R : Any>(
}
}

internal fun <S : Any, R : Any> GrpcStreamingCall<S, R>.asGrpcServerStreamingCall() = RealGrpcServerStreamingCall(this, method)
internal fun <S : Any, R : Any> GrpcStreamingCall<S, R>.asGrpcServerStreamingCall() = GrpcStreamingCallServerStreamingAdapter(this, method)

/**
* Wraps a [GrpcServerStreamingCall] as the legacy [GrpcStreamingCall] API. This is used by
* generated clients when explicit streaming call types are disabled.
*/
internal class GrpcServerStreamingCallStreamingAdapter<S : Any, R : Any>(
private val callDelegate: GrpcServerStreamingCall<S, R>,
override val method: GrpcMethod<S, R>,
) : GrpcStreamingCall<S, R> {
private var executed = false

override val timeout: Timeout
get() = callDelegate.timeout

override var requestMetadata: Map<String, String>
get() = callDelegate.requestMetadata
set(value) {
callDelegate.requestMetadata = value
}

override val responseMetadata: Map<String, String>?
get() = callDelegate.responseMetadata

override fun cancel() {
callDelegate.cancel()
}

override fun isCanceled() = callDelegate.isCanceled()

@Suppress("OPT_IN_USAGE", "OVERRIDE_DEPRECATION")
override fun execute(): Pair<SendChannel<S>, ReceiveChannel<R>> {
return executeIn(GlobalScope)
}

override fun executeIn(scope: CoroutineScope): Pair<SendChannel<S>, ReceiveChannel<R>> {
return executeWithChannels(scope)
}

@Suppress("OPT_IN_USAGE")
override fun executeBlocking(): Pair<MessageSink<S>, MessageSource<R>> {
val (requestChannel, responseChannel) = executeWithChannels(GlobalScope)
return requestChannel.toMessageSink() to responseChannel.toMessageSource()
}

override fun isExecuted() = executed || callDelegate.isExecuted()

override fun clone() = GrpcServerStreamingCallStreamingAdapter(callDelegate.clone(), method)

private fun executeWithChannels(scope: CoroutineScope): Pair<Channel<S>, Channel<R>> {
check(!executed) { "already executed" }
executed = true

val requestChannel = Channel<S>(1)
val responseChannel = Channel<R>(1)
var delegateResponseChannel: ReceiveChannel<R>? = null

responseChannel.invokeOnClose { cause ->
if (cause != null) {
requestChannel.cancel()
delegateResponseChannel?.cancel()
callDelegate.cancel()
}
}

scope.launch {
try {
val requestResult = requestChannel.receiveCatching()
requestResult.exceptionOrNull()?.let { throw it }
val request = requestResult.getOrNull()
?: throw ProtocolException("expected 1 message but got none")
requestChannel.close()
val responses = callDelegate.executeIn(scope, request)
delegateResponseChannel = responses
for (response in responses) {
responseChannel.send(response)
}
responseChannel.close()
} catch (e: Throwable) {
responseChannel.close(e)
}
}

return requestChannel to responseChannel
}
}

internal fun <S : Any, R : Any> GrpcServerStreamingCall<S, R>.asGrpcStreamingCall() = GrpcServerStreamingCallStreamingAdapter(this, method)

private fun <E : Any> Channel<E>.toMessageSource() = object : MessageSource<E> {
override fun read(): E? = runBlocking {
try {
val result = receiveCatching()
result.exceptionOrNull()?.let { throw it }
result.getOrNull()
} catch (e: Throwable) {
throw e.toIOException()
}
}

override fun close() {
cancel()
}
}

private fun <E : Any> Channel<E>.toMessageSink() = object : MessageSink<E> {
override fun write(message: E) {
runBlocking {
try {
send(message)
} catch (e: Throwable) {
throw e.toIOException()
}
}
}

override fun cancel() {
this@toMessageSink.cancel()
}

override fun close() {
this@toMessageSink.close()
}
}

private fun Throwable.toIOException() = this as? IOException ?: IOException(this)
Original file line number Diff line number Diff line change
Expand Up @@ -90,14 +90,23 @@ internal class RealGrpcStreamingCall<S : Any, R : Any>(
callForCancel = call,
)
}
call.enqueue(responseChannel.readFromResponseBodyCallback(this, method.responseAdapter))
call.enqueue(
responseChannel.readFromResponseBodyCallback(
onResponseMetadata = { this.responseMetadata = it },
responseAdapter = method.responseAdapter,
),
)

return requestChannel to responseChannel
}

override fun executeBlocking(): Pair<MessageSink<S>, MessageSource<R>> {
val call = initCall()
val messageSource = BlockingMessageSource(this, method.responseAdapter, call)
val messageSource = BlockingMessageSource(
onResponseMetadata = { this.responseMetadata = it },
responseAdapter = method.responseAdapter,
call = call,
)
val messageSink = requestBody.messageSink(
minMessageToCompress = grpcClient.minMessageToCompress,
requestAdapter = method.requestAdapter,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,15 @@ internal fun <S : Any> PipeDuplexRequestBody.messageSink(
internal fun <R : Any> SendChannel<R>.readFromResponseBodyCallback(
grpcCall: RealGrpcStreamingCall<*, R>,
responseAdapter: ProtoAdapter<R>,
): Callback = readFromResponseBodyCallback(
onResponseMetadata = { grpcCall.responseMetadata = it },
responseAdapter = responseAdapter,
)

/** Sends the response messages to the channel. */
internal fun <R : Any> SendChannel<R>.readFromResponseBodyCallback(
onResponseMetadata: (Map<String, String>) -> Unit,
responseAdapter: ProtoAdapter<R>,
): Callback {
return object : Callback {
override fun onFailure(call: Call, e: IOException) {
Expand All @@ -91,7 +100,7 @@ internal fun <R : Any> SendChannel<R>.readFromResponseBodyCallback(
}

override fun onResponse(call: Call, response: Response) {
grpcCall.responseMetadata = response.headers.toMap()
onResponseMetadata(response.headers.toMap())
runBlocking {
response.use {
val messageSource = try {
Expand Down
Loading
Loading