Skip to content

[CELEBORN-770][FLINK] Convert BacklogAnnouncement, BufferStreamEnd, ReadAddCredit to PB #1905

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
[CELEBORN-770][FLINK] Convert BacklogAnnouncement, BufferStreamEnd, R…
…eadAddCredit to PB
  • Loading branch information
SteNicholas committed Sep 22, 2023
commit 903dce48a9e4b0dba3088efbfc8e1f9a5bd0d271
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,10 @@

import org.apache.celeborn.common.network.protocol.BacklogAnnouncement;
import org.apache.celeborn.common.network.protocol.BufferStreamEnd;
import org.apache.celeborn.common.network.protocol.ReadAddCredit;
import org.apache.celeborn.common.network.protocol.RequestMessage;
import org.apache.celeborn.common.network.protocol.TransportableError;
import org.apache.celeborn.common.network.util.NettyUtils;
import org.apache.celeborn.common.protocol.PbReadAddCredit;
import org.apache.celeborn.plugin.flink.buffer.CreditListener;
import org.apache.celeborn.plugin.flink.buffer.TransferBufferPool;
import org.apache.celeborn.plugin.flink.protocol.ReadData;
Expand Down Expand Up @@ -115,8 +115,11 @@ public boolean isOpened() {

public void notifyAvailableCredits(int numCredits) {
if (!closed) {
ReadAddCredit addCredit = new ReadAddCredit(bufferStream.getStreamId(), numCredits);
bufferStream.addCredit(addCredit);
bufferStream.addCredit(
PbReadAddCredit.newBuilder()
.setStreamId(bufferStream.getStreamId())
.setCredit(numCredits)
.build());
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,10 @@

package org.apache.celeborn.plugin.flink.network;

import static org.apache.celeborn.common.protocol.MessageType.BACKLOG_ANNOUNCEMENT_VALUE;
import static org.apache.celeborn.common.protocol.MessageType.BUFFER_STREAM_END_VALUE;

import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.Consumer;
Expand All @@ -28,6 +32,7 @@
import org.apache.celeborn.common.network.protocol.BacklogAnnouncement;
import org.apache.celeborn.common.network.protocol.BufferStreamEnd;
import org.apache.celeborn.common.network.protocol.RequestMessage;
import org.apache.celeborn.common.network.protocol.TransportMessage;
import org.apache.celeborn.common.network.protocol.TransportableError;
import org.apache.celeborn.common.network.server.BaseMessageHandler;
import org.apache.celeborn.common.util.JavaUtils;
Expand Down Expand Up @@ -66,32 +71,43 @@ private void processMessageInternal(long streamId, RequestMessage msg) {

@Override
public void receive(TransportClient client, RequestMessage msg) {
long streamId = 0;
switch (msg.type()) {
case READ_DATA:
ReadData readData = (ReadData) msg;
streamId = readData.getStreamId();
processMessageInternal(streamId, readData);
processMessageInternal(readData.getStreamId(), readData);
break;
case BACKLOG_ANNOUNCEMENT:
BacklogAnnouncement backlogAnnouncement = (BacklogAnnouncement) msg;
streamId = backlogAnnouncement.getStreamId();
processMessageInternal(streamId, backlogAnnouncement);
processMessageInternal(backlogAnnouncement.getStreamId(), backlogAnnouncement);
break;
case TRANSPORTABLE_ERROR:
TransportableError transportableError = ((TransportableError) msg);
streamId = transportableError.getStreamId();
logger.warn(
"Received TransportableError from worker {} with content {}",
client.getSocketAddress().toString(),
transportableError.getErrorMessage());
processMessageInternal(streamId, transportableError);
processMessageInternal(transportableError.getStreamId(), transportableError);
break;
case BUFFER_STREAM_END:
BufferStreamEnd streamEnd = (BufferStreamEnd) msg;
logger.debug("Received streamend for {}", streamEnd.getStreamId());
processMessageInternal(streamEnd.getStreamId(), streamEnd);
break;
case RPC_REQUEST:
try {
TransportMessage transportMessage =
TransportMessage.fromByteBuffer(msg.body().nioByteBuffer());
switch (transportMessage.getMessageTypeValue()) {
case BACKLOG_ANNOUNCEMENT_VALUE:
receive(client, BacklogAnnouncement.fromProto(transportMessage.getParsedPayload()));
break;
case BUFFER_STREAM_END_VALUE:
receive(client, BufferStreamEnd.fromProto(transportMessage.getParsedPayload()));
break;
}
} catch (IOException e) {
logger.warn("Failed to process RpcRequest message {}. ", msg, e);
}
break;
case ONE_WAY_MESSAGE:
// ignore it.
break;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,15 @@

import org.apache.celeborn.common.network.client.RpcResponseCallback;
import org.apache.celeborn.common.network.client.TransportClient;
import org.apache.celeborn.common.network.protocol.*;
import org.apache.celeborn.common.network.protocol.RequestMessage;
import org.apache.celeborn.common.network.protocol.TransportMessage;
import org.apache.celeborn.common.network.protocol.TransportableError;
import org.apache.celeborn.common.network.util.NettyUtils;
import org.apache.celeborn.common.protocol.MessageType;
import org.apache.celeborn.common.protocol.PartitionLocation;
import org.apache.celeborn.common.protocol.PbBufferStreamEnd;
import org.apache.celeborn.common.protocol.PbOpenStream;
import org.apache.celeborn.common.protocol.PbReadAddCredit;
import org.apache.celeborn.common.protocol.PbStreamHandler;
import org.apache.celeborn.plugin.flink.network.FlinkTransportClientFactory;

Expand Down Expand Up @@ -82,21 +86,25 @@ public void open(
moveToNextPartitionIfPossible(0);
}

public void addCredit(ReadAddCredit addCredit) {
this.client
.getChannel()
.writeAndFlush(addCredit)
.addListener(
future -> {
if (future.isSuccess()) {
// Send ReadAddCredit do not expect response.
} else {
logger.warn(
"Send ReadAddCredit to {} failed, detail {}",
this.client.getSocketAddress().toString(),
future.cause());
}
});
public void addCredit(PbReadAddCredit pbReadAddCredit) {
this.client.sendRpc(
new TransportMessage(MessageType.READ_ADD_CREDIT, pbReadAddCredit.toByteArray())
.toByteBuffer(),
new RpcResponseCallback() {

@Override
public void onSuccess(ByteBuffer response) {
// Send PbReadAddCredit do not expect response.
}

@Override
public void onFailure(Throwable e) {
logger.warn(
"Send PbReadAddCredit to {} failed, detail {}",
NettyUtils.getRemoteAddress(client.getChannel()),
e.getCause());
}
});
}

public static CelebornBufferStream empty() {
Expand Down Expand Up @@ -127,7 +135,11 @@ public static CelebornBufferStream create(

private void closeStream(long streamId) {
if (client != null && client.isActive()) {
client.getChannel().writeAndFlush(new BufferStreamEnd(streamId));
client.sendRpc(
new TransportMessage(
MessageType.BUFFER_STREAM_END,
PbBufferStreamEnd.newBuilder().setStreamId(streamId).build().toByteArray())
.toByteBuffer());
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@

package org.apache.celeborn.plugin.flink.network;

import static org.apache.celeborn.common.network.client.TransportClient.requestId;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
Expand All @@ -31,9 +33,13 @@
import org.junit.Test;
import org.mockito.Mockito;

import org.apache.celeborn.common.network.protocol.BacklogAnnouncement;
import org.apache.celeborn.common.network.buffer.NioManagedBuffer;
import org.apache.celeborn.common.network.protocol.Message;
import org.apache.celeborn.common.network.protocol.ReadData;
import org.apache.celeborn.common.network.protocol.RpcRequest;
import org.apache.celeborn.common.network.protocol.TransportMessage;
import org.apache.celeborn.common.protocol.MessageType;
import org.apache.celeborn.common.protocol.PbBacklogAnnouncement;
import org.apache.celeborn.common.util.JavaUtils;

public class TransportFrameDecoderWithBufferSupplierSuiteJ {
Expand All @@ -57,10 +63,10 @@ public void testDropUnusedBytes() throws IOException {
new TransportFrameDecoderWithBufferSupplier(supplier);
ChannelHandlerContext context = Mockito.mock(ChannelHandlerContext.class);

BacklogAnnouncement announcement = new BacklogAnnouncement(0, 0);
RpcRequest announcement = createBacklogAnnouncement(0, 0);
ReadData unUsedReadData = new ReadData(1, generateData(1024));
ReadData readData = new ReadData(2, generateData(1024));
BacklogAnnouncement announcement1 = new BacklogAnnouncement(0, 0);
RpcRequest announcement1 = createBacklogAnnouncement(0, 0);
ReadData unUsedReadData1 = new ReadData(1, generateData(1024));
ReadData readData1 = new ReadData(2, generateData(8));

Expand Down Expand Up @@ -102,6 +108,20 @@ public void testDropUnusedBytes() throws IOException {
Assert.assertEquals(buffers.size(), 6);
}

public RpcRequest createBacklogAnnouncement(long streamId, int backlog) {
return new RpcRequest(
requestId(),
new NioManagedBuffer(
new TransportMessage(
MessageType.BACKLOG_ANNOUNCEMENT,
PbBacklogAnnouncement.newBuilder()
.setStreamId(streamId)
.setBacklog(backlog)
.build()
.toByteArray())
.toByteBuffer()));
}

public ByteBuf encodeMessage(Message in, ByteBuf byteBuf) throws IOException {
byteBuf.writeInt(in.encodedLength());
in.type().encode(byteBuf);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,22 @@ public long sendRpc(ByteBuffer message, RpcResponseCallback callback) {
return requestId;
}

/**
* Sends an opaque message to the RpcHandler on the server-side.
*
* @param message The message to send.
* @return The RPC's id.
*/
public long sendRpc(ByteBuffer message) {
if (logger.isTraceEnabled()) {
logger.trace("Sending RPC to {}", NettyUtils.getRemoteAddress(channel));
}

long requestId = requestId();
channel.writeAndFlush(new RpcRequest(requestId, new NioManagedBuffer(message)));
return requestId;
}

public ChannelFuture pushData(
PushData pushData, long pushDataTimeout, RpcResponseCallback callback) {
return pushData(pushData, pushDataTimeout, callback, null);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@

import io.netty.buffer.ByteBuf;

import org.apache.celeborn.common.protocol.PbBacklogAnnouncement;

// This RPC is sent to flink plugin to tell flink client to be ready for buffers.
public class BacklogAnnouncement extends RequestMessage {
private long streamId;
Expand Down Expand Up @@ -60,4 +62,8 @@ public long getStreamId() {
public int getBacklog() {
return backlog;
}

public static BacklogAnnouncement fromProto(PbBacklogAnnouncement pb) {
return new BacklogAnnouncement(pb.getStreamId(), pb.getBacklog());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@

import io.netty.buffer.ByteBuf;

import org.apache.celeborn.common.protocol.PbBufferStreamEnd;

public class BufferStreamEnd extends RequestMessage {
private long streamId;

Expand Down Expand Up @@ -49,4 +51,8 @@ public static Message decode(ByteBuf buffer) {
public long getStreamId() {
return streamId;
}

public static BufferStreamEnd fromProto(PbBufferStreamEnd pb) {
return new BufferStreamEnd(pb.getStreamId());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

import io.netty.buffer.ByteBuf;

@Deprecated
public class ReadAddCredit extends RequestMessage {
private long streamId;
private int credit;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,11 @@

package org.apache.celeborn.common.network.protocol;

import static org.apache.celeborn.common.protocol.MessageType.BACKLOG_ANNOUNCEMENT_VALUE;
import static org.apache.celeborn.common.protocol.MessageType.BUFFER_STREAM_END_VALUE;
import static org.apache.celeborn.common.protocol.MessageType.OPEN_STREAM_VALUE;
import static org.apache.celeborn.common.protocol.MessageType.PUSH_DATA_HAND_SHAKE_VALUE;
import static org.apache.celeborn.common.protocol.MessageType.READ_ADD_CREDIT_VALUE;
import static org.apache.celeborn.common.protocol.MessageType.REGION_FINISH_VALUE;
import static org.apache.celeborn.common.protocol.MessageType.REGION_START_VALUE;
import static org.apache.celeborn.common.protocol.MessageType.STREAM_HANDLER_VALUE;
Expand All @@ -33,8 +36,11 @@

import org.apache.celeborn.common.exception.CelebornIOException;
import org.apache.celeborn.common.protocol.MessageType;
import org.apache.celeborn.common.protocol.PbBacklogAnnouncement;
import org.apache.celeborn.common.protocol.PbBufferStreamEnd;
import org.apache.celeborn.common.protocol.PbOpenStream;
import org.apache.celeborn.common.protocol.PbPushDataHandShake;
import org.apache.celeborn.common.protocol.PbReadAddCredit;
import org.apache.celeborn.common.protocol.PbRegionFinish;
import org.apache.celeborn.common.protocol.PbRegionStart;
import org.apache.celeborn.common.protocol.PbStreamHandler;
Expand Down Expand Up @@ -76,6 +82,12 @@ public <T extends GeneratedMessageV3> T getParsedPayload() throws InvalidProtoco
return (T) PbRegionStart.parseFrom(payload);
case REGION_FINISH_VALUE:
return (T) PbRegionFinish.parseFrom(payload);
case BACKLOG_ANNOUNCEMENT_VALUE:
return (T) PbBacklogAnnouncement.parseFrom(payload);
case BUFFER_STREAM_END_VALUE:
return (T) PbBufferStreamEnd.parseFrom(payload);
case READ_ADD_CREDIT_VALUE:
return (T) PbReadAddCredit.parseFrom(payload);
default:
logger.error("Unexpected type {}", type);
}
Expand Down
21 changes: 19 additions & 2 deletions common/src/main/proto/TransportMessages.proto
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,9 @@ enum MessageType {
PUSH_DATA_HAND_SHAKE = 56;
REGION_START = 57;
REGION_FINISH = 58;
BACKLOG_ANNOUNCEMENT = 59;
BUFFER_STREAM_END = 60;
READ_ADD_CREDIT = 61;
}

message PbStorageInfo {
Expand Down Expand Up @@ -498,9 +501,9 @@ message PbOpenStream {
}

message PbStreamHandler {
int64 streamId = 1 ;
int64 streamId = 1;
int32 numChunks = 2;
repeated int64 chunkOffsets = 3 ;
repeated int64 chunkOffsets = 3;
string fullPath = 4;
}

Expand Down Expand Up @@ -528,3 +531,17 @@ message PbRegionFinish {
string partitionUniqueId = 3;
int32 attemptId = 4;
}

message PbBacklogAnnouncement {
int64 streamId = 1;
int32 backlog = 2;
}

message PbBufferStreamEnd {
int64 streamId = 1;
}

message PbReadAddCredit {
int64 streamId = 1;
int32 credit = 2;
}
Loading