
Commit f84a560

Merge pull request #70 from xuechendi/wip_memory
[Scala] Optimize Memory management and Track

2 parents: 26c766a + 8030c0b

9 files changed: +134 lines, -24 lines

core/src/main/java/org/apache/spark/network/pmof/ShuffleBuffer.java (3 additions, 2 deletions)

```diff
@@ -4,6 +4,7 @@
 import io.netty.buffer.ByteBuf;
 import io.netty.buffer.PooledByteBufAllocator;
 import org.apache.spark.network.buffer.ManagedBuffer;
+import org.apache.spark.storage.pmof.NettyByteBufferPool;
 import org.apache.spark.unsafe.memory.MemoryBlock;
 import org.apache.spark.unsafe.memory.UnsafeMemoryAllocator;
 import sun.nio.ch.FileChannelImpl;
@@ -54,7 +55,7 @@ public ShuffleBuffer(long length, EqService service, boolean supportNettyBuffer)
       this.byteBuffer = convertToByteBuffer();
       this.byteBuffer.limit((int)length);
     } else {
-      this.buf = PooledByteBufAllocator.DEFAULT.directBuffer((int) this.length, (int)this.length);
+      this.buf = NettyByteBufferPool.allocateNewBuffer((int) this.length);
       this.address = this.buf.memoryAddress();
       this.byteBuffer = this.buf.nioBuffer(0, (int)length);
     }
@@ -135,7 +136,7 @@ public ManagedBuffer close() {
       }
     } else {
       if (this.supportNettyBuffer) {
-        this.buf.release();
+        NettyByteBufferPool.releaseBuffer(this.buf);
       } else {
         unsafeAlloc.free(memoryBlock);
       }
```
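This change reroutes direct Netty allocator calls through the new NettyByteBufferPool (added later in this commit), so every direct buffer is counted at allocation and at release; note the pool pins the first requested size as its expected fixed size and rejects larger requests. A minimal sketch of the pairing the pool expects from callers, using only the pool API from this commit:

```scala
import io.netty.buffer.ByteBuf
import org.apache.spark.storage.pmof.NettyByteBufferPool

// Illustrative only: any buffer taken from the pool must go back through
// releaseBuffer, otherwise the pool's refCnt/byte counters drift.
val buf: ByteBuf = NettyByteBufferPool.allocateNewBuffer(4096)
try {
  buf.writeLong(42L)                     // use the buffer
} finally {
  NettyByteBufferPool.releaseBuffer(buf) // decrements count and live bytes
}
```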

core/src/main/java/org/apache/spark/storage/pmof/PmemBuffer.java (9 additions, 4 deletions)

```diff
@@ -15,12 +15,15 @@ public class PmemBuffer {
   private native long nativeDeletePmemBuffer(long pmBuffer);

   private boolean closed = false;
+  private long len = 0;
   long pmBuffer;
   PmemBuffer() {
     pmBuffer = nativeNewPmemBuffer();
   }

   PmemBuffer(long len) {
+    this.len = len;
+    NettyByteBufferPool.unpooledInc(len);
     pmBuffer = nativeNewPmemBufferBySize(len);
   }

@@ -48,6 +51,7 @@ void put(byte[] bytes, int off, int len) {
   }

   void clean() {
+    NettyByteBufferPool.unpooledDec(len);
     nativeCleanPmemBuffer(pmBuffer);
   }

@@ -60,9 +64,10 @@ long getDirectAddr() {
   }

   synchronized void close() {
-    if (!closed) {
-      nativeDeletePmemBuffer(pmBuffer);
-      closed = true;
-    }
+    if (!closed) {
+      clean();
+      nativeDeletePmemBuffer(pmBuffer);
+      closed = true;
+    }
   }
 }
```
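The key fix here is that close() now funnels through clean(), so the native-buffer counter is decremented exactly once per buffer even when callers never issue an explicit clean(). A hypothetical Scala mirror of that lifecycle, using only the pool calls from this commit (the real class is Java backed by JNI):

```scala
import org.apache.spark.storage.pmof.NettyByteBufferPool

// Hypothetical mirror of PmemBuffer's accounting: the constructor increments
// the "unpooled" (native) counter and close() decrements it via clean(),
// guarded so the decrement runs at most once.
class TrackedNativeBuffer(len: Long) {
  NettyByteBufferPool.unpooledInc(len)
  private var closed = false

  def clean(): Unit = NettyByteBufferPool.unpooledDec(len)

  def close(): Unit = synchronized {
    if (!closed) {
      clean() // balance the counter before the native delete
      closed = true
    }
  }
}
```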
core/src/main/scala/org/apache/spark/storage/pmof/NettyByteBufferPool.scala (new file, 79 additions)

```scala
package org.apache.spark.storage.pmof

import java.util.concurrent.atomic.AtomicLong
import io.netty.buffer.{ByteBuf, PooledByteBufAllocator, UnpooledByteBufAllocator}
import scala.collection.mutable.Stack
import java.lang.RuntimeException
import org.apache.spark.internal.Logging

object NettyByteBufferPool extends Logging {
  private val allocatedBufRenCnt: AtomicLong = new AtomicLong(0)
  private val allocatedBytes: AtomicLong = new AtomicLong(0)
  private val peakAllocatedBytes: AtomicLong = new AtomicLong(0)
  private val unpooledAllocatedBytes: AtomicLong = new AtomicLong(0)
  private var fixedBufferSize: Long = 0
  private val allocatedBufferPool: Stack[ByteBuf] = Stack[ByteBuf]()
  private var reachRead = false
  private val allocator = UnpooledByteBufAllocator.DEFAULT

  def allocateNewBuffer(bufSize: Int): ByteBuf = synchronized {
    if (fixedBufferSize == 0) {
      fixedBufferSize = bufSize
    } else if (bufSize > fixedBufferSize) {
      throw new RuntimeException(s"allocateNewBuffer, expected size is ${fixedBufferSize}, actual size is ${bufSize}")
    }
    allocatedBufRenCnt.getAndIncrement()
    allocatedBytes.getAndAdd(bufSize)
    if (allocatedBytes.get > peakAllocatedBytes.get) {
      peakAllocatedBytes.set(allocatedBytes.get)
    }
    try {
      /*if (allocatedBufferPool.isEmpty == false) {
        allocatedBufferPool.pop
      } else {
        allocator.directBuffer(bufSize, bufSize)
      }*/
      allocator.directBuffer(bufSize, bufSize)
    } catch {
      case e: Throwable =>
        logError(s"allocateNewBuffer size is ${bufSize}")
        throw e
    }
  }

  def releaseBuffer(buf: ByteBuf): Unit = synchronized {
    allocatedBufRenCnt.getAndDecrement()
    allocatedBytes.getAndAdd(0 - fixedBufferSize)
    buf.clear()
    //allocatedBufferPool.push(buf)
    buf.release(buf.refCnt())
  }

  def unpooledInc(bufSize: Int): Unit = synchronized {
    if (reachRead == false) {
      reachRead = true
      peakAllocatedBytes.set(0)
    }
    unpooledAllocatedBytes.getAndAdd(bufSize)
  }

  def unpooledDec(bufSize: Int): Unit = synchronized {
    unpooledAllocatedBytes.getAndAdd(0 - bufSize)
  }

  def unpooledInc(bufSize: Long): Unit = synchronized {
    if (reachRead == false) {
      reachRead = true
      peakAllocatedBytes.set(0)
    }
    unpooledAllocatedBytes.getAndAdd(bufSize)
  }

  def unpooledDec(bufSize: Long): Unit = synchronized {
    unpooledAllocatedBytes.getAndAdd(0 - bufSize)
  }

  override def toString(): String = synchronized {
    return s"NettyBufferPool [refCnt|allocatedBytes|Peak|Native] is [${allocatedBufRenCnt.get}|${allocatedBytes.get}|${peakAllocatedBytes.get}|${unpooledAllocatedBytes.get}]"
  }
}
```
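Since toString is overridden, the pool doubles as a cheap diagnostic: dumping it at any point shows the live buffer count, live bytes, the peak, and native-side bytes, and a nonzero refCnt after a stage finishes points at a leak. A small usage sketch (the printed values are illustrative only):

```scala
import org.apache.spark.storage.pmof.NettyByteBufferPool

// Prints something like:
// NettyBufferPool [refCnt|allocatedBytes|Peak|Native] is [3|6291456|8388608|2098176]
println(NettyByteBufferPool)
```

This is the same mechanism behind the logDebug call added in PmemExternalSorter further below.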

core/src/main/scala/org/apache/spark/storage/pmof/PmemBlockInputStream.scala (3 additions, 1 deletion)

```diff
@@ -11,7 +11,8 @@ class PmemBlockInputStream[K, C](pmemBlockOutputStream: PmemBlockOutputStream, s
   val serInstance: SerializerInstance = serializer.newInstance()
   val persistentMemoryWriter: PersistentMemoryHandler = PersistentMemoryHandler.getPersistentMemoryHandler
   var pmemInputStream: PmemInputStream = new PmemInputStream(persistentMemoryWriter, blockId.name)
-  var inObjStream: DeserializationStream = serInstance.deserializeStream(pmemInputStream)
+  val wrappedStream = serializerManager.wrapStream(blockId, pmemInputStream)
+  var inObjStream: DeserializationStream = serInstance.deserializeStream(wrappedStream)

   var total_records: Long = 0
   var indexInBatch: Int = 0
@@ -45,6 +46,7 @@ class PmemBlockInputStream[K, C](pmemBlockOutputStream: PmemBlockOutputStream, s
   }

   def close(): Unit = {
+    inObjStream.close
     pmemInputStream.close
     inObjStream = null
   }
```
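Wrapping the pmem stream with serializerManager.wrapStream matters because Spark's SerializerManager applies the configured per-block compression (and encryption, if enabled); the write side in the next file wraps its output stream the same way, so the two sides now round-trip. A sketch of the symmetric pattern, assuming a SerializerManager instance as in the diff:

```scala
import java.io.{InputStream, OutputStream}
import org.apache.spark.serializer.SerializerManager
import org.apache.spark.storage.BlockId

// Sketch: whatever wrapping is applied on write must be applied on read,
// or the deserializer sees compressed/encrypted bytes as garbage.
def wrappedPair(sm: SerializerManager, blockId: BlockId,
                rawIn: InputStream, rawOut: OutputStream): (InputStream, OutputStream) =
  (sm.wrapStream(blockId, rawIn), sm.wrapStream(blockId, rawOut))
```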

core/src/main/scala/org/apache/spark/storage/pmof/PmemBlockOutputStream.scala (8 additions, 3 deletions)

```diff
@@ -52,9 +52,10 @@ private[spark] class PmemBlockOutputStream(
   //persistentMemoryWriter.updateShuffleMeta(blockId.name)

   val pmemOutputStream: PmemOutputStream = new PmemOutputStream(
-    persistentMemoryWriter, numPartitions, blockId.name, numMaps)
+    persistentMemoryWriter, numPartitions, blockId.name, numMaps, (pmofConf.spill_throttle.toInt + 1024))
   val serInstance = serializer.newInstance()
-  var objStream: SerializationStream = serInstance.serializeStream(pmemOutputStream)
+  val bs = serializerManager.wrapStream(blockId, pmemOutputStream)
+  var objStream: SerializationStream = serInstance.serializeStream(bs)

   override def write(key: Any, value: Any): Unit = {
     objStream.writeKey(key)
@@ -68,12 +69,16 @@
   }

   override def close() {
+    if (objStream != null) {
+      objStream.close()
+      objStream = null
+    }
     pmemOutputStream.close()
-    objStream = null
   }

   override def flush() {
     objStream.flush()
+    bs.flush()
   }

   def maybeSpill(force: Boolean = false): Unit = {
```
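With the wrapper in place the write path is layered, which is why flush() now drains two levels: objStream.flush() pushes serializer-buffered records into the wrapper bs, and bs.flush() pushes the codec's buffered output into pmemOutputStream. A self-contained illustration, with GZIP standing in for whatever codec serializerManager.wrapStream configures:

```scala
import java.io.ByteArrayOutputStream
import java.util.zip.GZIPOutputStream

val sink = new ByteArrayOutputStream()    // plays pmemOutputStream
val bs = new GZIPOutputStream(sink, true) // plays the wrapper; syncFlush = true
bs.write("record".getBytes)
// Without an explicit flush of the wrapper, compressed bytes can sit in the
// codec's internal buffer and never reach the sink before a spill is sized.
bs.flush()
```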

core/src/main/scala/org/apache/spark/storage/pmof/PmemManagedBuffer.scala (14 additions, 4 deletions)

```diff
@@ -5,13 +5,15 @@ import java.nio.ByteBuffer
 import sun.misc.Cleaner
 import io.netty.buffer.Unpooled
 import java.util.concurrent.atomic.AtomicInteger
+import io.netty.buffer.ByteBuf

 import org.apache.spark.internal.Logging
 import org.apache.spark.network.buffer.ManagedBuffer

 class PmemManagedBuffer(pmHandler: PersistentMemoryHandler, blockId: String) extends ManagedBuffer with Logging {
   var inputStream: InputStream = _
   var total_size: Long = -1
+  var buf: ByteBuf = _
   var byteBuffer: ByteBuffer = _
   private val refCount = new AtomicInteger(1)

@@ -26,8 +28,13 @@
     // TODO: This function should be Deprecated by spark in near future.
     val data_length = size().toInt
     val in = createInputStream()
-    byteBuffer = ByteBuffer.allocateDirect(data_length)
     val data = Array.ofDim[Byte](data_length)
+    if (buf == null) {
+      buf = NettyByteBufferPool.allocateNewBuffer(data_length)
+      byteBuffer = buf.nioBuffer(0, data_length)
+    } else {
+      byteBuffer.clear()
+    }
     in.read(data)
     byteBuffer.put(data)
     byteBuffer.flip()
@@ -48,12 +55,15 @@

   override def release(): ManagedBuffer = {
     if (refCount.decrementAndGet() == 0) {
-      if (byteBuffer != null) {
+      if (buf != null) {
+        NettyByteBufferPool.releaseBuffer(buf)
+      }
+      /*if (byteBuffer != null) {
         val cleanerField: java.lang.reflect.Field = byteBuffer.getClass.getDeclaredField("cleaner")
         cleanerField.setAccessible(true)
         val cleaner: Cleaner = cleanerField.get(byteBuffer).asInstanceOf[Cleaner]
         cleaner.clean()
-      }
+      }*/
       if (inputStream != null) {
         inputStream.close()
       }
@@ -62,8 +72,8 @@
   }

   override def convertToNetty(): Object = {
-    val data_length = size().toInt
     val in = createInputStream()
+    val data_length = size().toInt
     Unpooled.wrappedBuffer(in.asInstanceOf[PmemInputStream].getByteBufferDirectAddr, data_length, false)
   }
 }
```
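A design note on the release path: the reflective sun.misc.Cleaner call (now commented out) freed the direct ByteBuffer by reaching into a JDK-internal field, which is fragile and blocked on newer JDKs; returning the pooled ByteBuf through NettyByteBufferPool.releaseBuffer frees the memory and keeps the pool's counters balanced in a single call.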

core/src/main/scala/org/apache/spark/storage/pmof/PmemOutputStream.scala (5 additions, 4 deletions)

```diff
@@ -10,15 +10,16 @@ class PmemOutputStream(
   persistentMemoryWriter: PersistentMemoryHandler,
   numPartitions: Int,
   blockId: String,
-  numMaps: Int
+  numMaps: Int,
+  bufferSize: Int
 ) extends OutputStream with Logging {
   var set_clean = true
   var is_closed = false

-  val length: Int = 1024*1024*6
+  val length: Int = bufferSize
   var bufferFlushedSize: Int = 0
   var bufferRemainingSize: Int = 0
-  val buf: ByteBuf = PooledByteBufAllocator.DEFAULT.directBuffer(length, length)
+  val buf: ByteBuf = NettyByteBufferPool.allocateNewBuffer(length)
   val byteBuffer: ByteBuffer = buf.nioBuffer(0, length)

   override def write(bytes: Array[Byte], off: Int, len: Int): Unit = {
@@ -60,7 +61,7 @@
     if (!is_closed) {
       flush()
       reset()
-      buf.release()
+      NettyByteBufferPool.releaseBuffer(buf)
       is_closed = true
     }
   }
```
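PmemOutputStream's staging buffer is no longer hard-coded to 6 MiB; the caller (PmemBlockOutputStream above) sizes it as pmofConf.spill_throttle.toInt + 1024, i.e. the spill threshold plus 1 KiB of headroom for bytes written between throttle checks. A worked example with an assumed 2 MiB throttle (the real default lives in PmofConf, which is not part of this diff):

```scala
// Assumed throttle value, for illustration only.
val spillThrottle: Int = 2 * 1024 * 1024     // 2 MiB
val bufferSize: Int   = spillThrottle + 1024 // 2098176 bytes = 2049 * 1024
```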

core/src/main/scala/org/apache/spark/util/collection/pmof/PmemExternalSorter.scala (5 additions, 0 deletions)

```diff
@@ -103,9 +103,12 @@ private[spark] class PmemExternalSorter[K, V, C](
       if (cur_partitionId != partitionId) {
         if (cur_partitionId != -1) {
           buffer.maybeSpill(true)
+          buffer.close()
+          buffer = null
         }
         cur_partitionId = partitionId
         buffer = getPartitionByteBufferArray(dep.shuffleId, cur_partitionId)
+        logDebug(s"${dep.shuffleId}_${cur_partitionId} ${NettyByteBufferPool}")
       }
       require(partitionId >= 0 && partitionId < numPartitions,
         s"partition Id: ${partitionId} should be in the range [0, ${numPartitions})")
@@ -115,6 +118,8 @@
     }
     if (buffer != null) {
       buffer.maybeSpill(true)
+      buffer.close()
+      buffer = null
     }
   }
```
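Previously the sorter left each partition's PmemBlockOutputStream open after its final spill, so one pooled buffer per partition stayed live until the task ended; closing and nulling the writer as soon as a partition is done keeps at most one pooled buffer live per sorter. A sketch of the pattern, using the methods from the diffs above with a hypothetical helper name:

```scala
import org.apache.spark.storage.pmof.PmemBlockOutputStream

// Hypothetical helper illustrating the spill-then-close pattern this diff
// applies both at partition switches and at the end of iteration.
def finishPartition(buffer: PmemBlockOutputStream): Unit = {
  if (buffer != null) {
    buffer.maybeSpill(force = true) // push remaining records to pmem
    buffer.close()                  // returns the pooled ByteBuf immediately
  }
}
```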

native/src/PmemBuffer.h (8 additions, 6 deletions)

```diff
@@ -6,7 +6,7 @@
 #include <cstring>
 using namespace std;

-#define DEFAULT_BUFSIZE 4096*1024+512
+#define DEFAULT_BUFSIZE 2049 * 1024

 class PmemBuffer {
  public:
@@ -17,6 +17,7 @@ class PmemBuffer {
     pos = 0;
     pos_dirty = 0;
   }
+
   explicit PmemBuffer(long initial_buf_data_capacity) {
     buf_data_capacity = initial_buf_data_capacity;
     buf_data = (char*)malloc(sizeof(char) * buf_data_capacity);
@@ -39,11 +40,11 @@
     std::lock_guard<std::mutex> lock(buffer_mtx);
     if (buf_data_capacity == 0 && pmem_data_len > 0) {
       buf_data_capacity = pmem_data_len;
-      buf_data = (char*)malloc(sizeof(char) * pmem_data_len);
+      buf_data = (char*)malloc(sizeof(char) * buf_data_capacity);
     }

     if (remaining > 0) {
-      if (buf_data_capacity < remaining+pmem_data_len) {
+      if (buf_data_capacity < remaining + pmem_data_len) {
         buf_data_capacity = remaining + pmem_data_len;
         char* tmp_buf_data = buf_data;
         buf_data = (char*)malloc(sizeof(char) * buf_data_capacity);
@@ -118,6 +119,10 @@
     return read_len;
   }

+  char* getDataAddr() {
+    return buf_data;
+  }
+
   int write(char* data, int len) {
     std::lock_guard<std::mutex> lock(buffer_mtx);
     if (buf_data_capacity == 0) {
@@ -149,9 +154,6 @@
     return 0;
   }

-  char* getDataAddr() {
-    return buf_data;
-  }

  private:
   mutex buffer_mtx;
```
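The macro change roughly halves the default native staging buffer, and the new value lines up with the JVM-side sizing above. Worked numbers (Scala used just for the arithmetic):

```scala
val oldSize = 4096 * 1024 + 512 // 4194816 bytes, about 4 MiB
val newSize = 2049 * 1024       // 2098176 bytes = 2 MiB + 1 KiB, which matches
                                // spill_throttle + 1024 assuming a 2 MiB throttle
```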
