
Set max memory of SortBasedPusher based off Spark configs #3203


Open · wants to merge 1 commit into main
Set max memory of SortBasedPusher based off Spark configs
helenweng-stripe committed Apr 7, 2025
commit ffb86930541fb1a8cc1a8c6ad27d44e10990d887
@@ -23,11 +23,11 @@
import java.util.concurrent.atomic.LongAdder;
import java.util.function.Consumer;

import com.google.common.annotations.VisibleForTesting;
import org.apache.spark.SparkConf;
import org.apache.spark.TaskContext;
import org.apache.spark.memory.MemoryConsumer;
import org.apache.spark.memory.SparkOutOfMemoryError;
import org.apache.spark.memory.TaskMemoryManager;
import org.apache.spark.memory.TooLargePageException;
import org.apache.spark.internal.config.package$;
import org.apache.spark.memory.*;
import org.apache.spark.unsafe.Platform;
import org.apache.spark.unsafe.UnsafeAlignedOffset;
import org.apache.spark.unsafe.array.LongArray;
@@ -49,11 +49,22 @@ class MemoryThresholdManager {
private final long maxMemoryThresholdInBytes;
private final double smallPushTolerateFactor;

MemoryThresholdManager(double maxMemoryFactor, double smallPushTolerateFactor) {
this.maxMemoryThresholdInBytes = (long) (Runtime.getRuntime().maxMemory() * maxMemoryFactor);
MemoryThresholdManager(
double maxMemoryFactor, double smallPushTolerateFactor, Long maxTaskMemory) {
this.maxMemoryThresholdInBytes =
(long)
((maxTaskMemory <= 0L ? Runtime.getRuntime().maxMemory() : maxTaskMemory)
* maxMemoryFactor);

logger.info("setting max memory threshold to " + String.valueOf(maxMemoryThresholdInBytes));
this.smallPushTolerateFactor = smallPushTolerateFactor;
}

@VisibleForTesting
protected long getMaxMemoryThresholdInBytes() {
return maxMemoryThresholdInBytes;
}

private boolean shouldGrow() {
boolean enoughSpace = pushSortMemoryThreshold <= maxMemoryThresholdInBytes;
double expectedPushSize = Long.MAX_VALUE;
@@ -197,7 +208,10 @@ public SortBasedPusher(
this.pushSortMemoryThreshold = pushSortMemoryThreshold;

this.memoryThresholdManager =
new MemoryThresholdManager(maxMemoryFactor, conf.clientPushSortSmallPushTolerateFactor());
new MemoryThresholdManager(
maxMemoryFactor,
conf.clientPushSortSmallPushTolerateFactor(),
conf.clientPushSortMaxMemoryBytes());

int initialSize = Math.min((int) pushSortMemoryThreshold / 8, 1024 * 1024);
this.inMemSorter = new ShuffleInMemorySorter(this, initialSize);
@@ -518,4 +532,48 @@ public void close(boolean throwTaskKilledOnInterruption) throws IOException {
public long getUsed() {
return super.getUsed();
}

//

/**
* Calculates max memory conf based on SparkConf settings. Follows logic in Spark
* UnifiedMemoryManager.getMaxMemory:
* github.com/apache/spark/blob/branch-3.3/core/src/main/scala/org/apache/spark/memory/UnifiedMemoryManager.scala#L213
*/
public static CelebornConf setMemoryConfs(
SparkConf sparkConf, CelebornConf celebornConf, int cores, MemoryMode memoryMode) {
try {
if (!celebornConf.clientPushSortCalculateMaxMemoryBytes()) {
return celebornConf;
}

// set max task memory conf based on Spark conf
if (celebornConf.clientPushSortMaxMemoryBytes() <= 0L) {
double memoryStorageFraction =
sparkConf.getDouble(package$.MODULE$.MEMORY_FRACTION().key(), 0.6);
Member: For spark-2.4 UT

Error:  /home/runner/work/celeborn/celeborn/client-spark/common/src/main/java/org/apache/spark/shuffle/celeborn/SortBasedPusher.java:553:48:  error: cannot find symbol

Member: Maybe you can use the string value directly?
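
A minimal sketch of that suggestion, assuming plain string keys are an acceptable substitute for the entries in org.apache.spark.internal.config.package$ (the default values shown are assumptions, not taken from this PR):

// Look the settings up by their string keys so the code also compiles against
// Spark 2.4, where these entries are missing from the internal config package.
double memoryStorageFraction = sparkConf.getDouble("spark.memory.fraction", 0.6);
long maxExecutorMemory = sparkConf.getSizeAsBytes("spark.executor.memory", "1g");
long maxOffHeapMemory = sparkConf.getSizeAsBytes("spark.memory.offHeap.size", "0");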

Contributor: Can we get the values of maxOffHeapMemory/maxHeapMemory from the MemoryManager through reflection?

Author: Thank you for the feedback, I will try using reflection and re-submit.
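
A rough sketch of the reflection approach discussed above; the field names ("maxHeapMemory", "maxOffHeapMemory") are assumptions about Spark's MemoryManager/UnifiedMemoryManager internals and can differ across Spark versions, so the caller should keep the SparkConf-based calculation as a fallback:

// Walk the class hierarchy for a (possibly private) numeric field and read it.
// Returns -1 so the caller can fall back to the SparkConf-based estimate.
private static long readLongField(Object memoryManager, String fieldName) {
  for (Class<?> c = memoryManager.getClass(); c != null; c = c.getSuperclass()) {
    try {
      java.lang.reflect.Field f = c.getDeclaredField(fieldName);
      f.setAccessible(true);
      return ((Number) f.get(memoryManager)).longValue();
    } catch (NoSuchFieldException ignored) {
      // field not declared at this level; keep searching the superclasses
    } catch (ReflectiveOperationException e) {
      break;
    }
  }
  return -1L;
}

// Hypothetical usage, given a handle to the executor's MemoryManager:
// long maxHeapMemory = readLongField(memoryManager, "maxHeapMemory");
// long maxOffHeapMemory = readLongField(memoryManager, "maxOffHeapMemory");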

if (memoryMode == MemoryMode.ON_HEAP
&& sparkConf.contains(package$.MODULE$.EXECUTOR_MEMORY())) {
long maxExecutorMemory =
sparkConf.getSizeAsBytes(package$.MODULE$.EXECUTOR_MEMORY().key());
long reservedSystemMemory =
sparkConf.getBoolean("spark.testing", false) ? 0 : 300 * 1024 * 1024;
long totalAvailOnHeapMem =
Math.round(
(maxExecutorMemory - reservedSystemMemory) * memoryStorageFraction / cores);
celebornConf.set(CelebornConf.CLIENT_PUSH_SORT_MAX_MEMORY_BYTES(), totalAvailOnHeapMem);
} else if (memoryMode == MemoryMode.OFF_HEAP
&& sparkConf.contains(package$.MODULE$.MEMORY_OFFHEAP_SIZE())) {
long maxOffHeapMemory =
sparkConf.getSizeAsBytes(package$.MODULE$.MEMORY_OFFHEAP_SIZE().key());
long totalAvailOffHeapMem =
Math.round(maxOffHeapMemory * (memoryStorageFraction) / cores);
celebornConf.set(CelebornConf.CLIENT_PUSH_SORT_MAX_MEMORY_BYTES(), totalAvailOffHeapMem);
}
}
} catch (Exception e) {
logger.error(
"SortBasedPusher.setMemoryConfs failed to set memory confs, threw exception:", e);
}
return celebornConf;
}
}
@@ -29,6 +29,7 @@
import org.apache.spark.SparkConf;
import org.apache.spark.TaskContext$;
import org.apache.spark.TaskContextImpl;
import org.apache.spark.memory.MemoryMode;
import org.apache.spark.memory.TaskMemoryManager;
import org.apache.spark.memory.UnifiedMemoryManager;
import org.apache.spark.sql.catalyst.InternalRow;
@@ -39,6 +40,7 @@
import org.junit.AfterClass;
import org.junit.BeforeClass;
import org.junit.Test;
import org.mockito.Mockito;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

@@ -132,6 +134,53 @@ public void testMemoryUsage() throws Exception {
assertEquals(taskContext.taskMetrics().memoryBytesSpilled(), 2097152);
}

@Test
public void testSetMemoryConfsDisabled() {
// celeborn.client.spark.push.sort.memory.calculateMaxMemoryBytes disabled by default
CelebornConf conf = new CelebornConf();
assertEquals(false, conf.clientPushSortCalculateMaxMemoryBytes());
SortBasedPusher.setMemoryConfs(sparkConf, conf, 1, MemoryMode.ON_HEAP);
assertEquals(0L, conf.clientPushSortMaxMemoryBytes());
}

@Test
public void testSetMemoryConfsOnHeap() {
// explicitly enable celeborn.client.spark.push.sort.memory.calculateMaxMemoryBytes
CelebornConf conf =
new CelebornConf().set(CelebornConf.CLIENT_PUSH_SORT_CALCULATE_MAX_MEMORY_BYTES(), true);
SparkConf sparkConf =
new SparkConf(false)
.set("spark.executor.memory", Integer.toString(1200 * 1024 * 1024))
.set("spark.memory.offHeap.size", Integer.toString(600 * 1024 * 1024));
SortBasedPusher.setMemoryConfs(sparkConf, conf, 4, MemoryMode.ON_HEAP);
// (1200m - 300m reserved) * .6 memory fraction / 4 cores = 135
assertEquals(135 * 1024 * 1024, conf.clientPushSortMaxMemoryBytes());
}

@Test
public void testSetMemoryConfsOffHeap() {
// explicitly enable celeborn.client.spark.push.sort.memory.calculateMaxMemoryBytes
CelebornConf conf =
new CelebornConf().set(CelebornConf.CLIENT_PUSH_SORT_CALCULATE_MAX_MEMORY_BYTES(), true);
SparkConf sparkConf =
new SparkConf(false)
.set("spark.executor.memory", Integer.toString(1200 * 1024 * 1024))
.set("spark.memory.offHeap.size", Integer.toString(600 * 1024 * 1024))
.set("spark.memory.fraction", Double.toString(0.8));
SortBasedPusher.setMemoryConfs(sparkConf, conf, 3, MemoryMode.OFF_HEAP);
// 600m off-heap * .8 memory fraction / 3 cores = 160
assertEquals(160 * 1024 * 1024, conf.clientPushSortMaxMemoryBytes());
}

@Test
public void testSetMemoryConfsException() {
SparkConf sparkConf = new SparkConf(false);
CelebornConf celebornConf = Mockito.mock(CelebornConf.class);
Mockito.when(celebornConf.clientPushSortCalculateMaxMemoryBytes())
.thenThrow(new RuntimeException("Test exception"));
SortBasedPusher.setMemoryConfs(sparkConf, celebornConf, 1, MemoryMode.ON_HEAP);
}

private static UnsafeRow genUnsafeRow(int size) {
ListBuffer<Object> values = new ListBuffer<>();
byte[] bytes = new byte[size];
@@ -298,7 +298,8 @@ public <K, V> ShuffleWriter<K, V> getWriter(
h.dependency(),
h.numMappers(),
context,
celebornConf,
SortBasedPusher.setMemoryConfs(
conf, celebornConf, cores, context.taskMemoryManager().getTungstenMemoryMode()),
shuffleClient,
metrics,
SendBufferPool.get(cores, sendBufferPoolCheckInterval, sendBufferPoolExpireTimeout));