diff --git a/src/provider/provider_level_zero.c b/src/provider/provider_level_zero.c index f5916682f..b75c14094 100644 --- a/src/provider/provider_level_zero.c +++ b/src/provider/provider_level_zero.c @@ -107,6 +107,8 @@ typedef struct ze_ops_t { ze_result_t (*zeContextMakeMemoryResident)(ze_context_handle_t, ze_device_handle_t, void *, size_t); + ze_result_t (*zeContextEvictMemory)(ze_context_handle_t, ze_device_handle_t, + void *, size_t); ze_result_t (*zeDeviceGetProperties)(ze_device_handle_t, ze_device_properties_t *); ze_result_t (*zeMemFreeExt)(ze_context_handle_t, @@ -218,6 +220,8 @@ static void init_ze_global_state(void) { utils_get_symbol_addr(lib_handle, "zeMemCloseIpcHandle", lib_name); *(void **)&g_ze_ops.zeContextMakeMemoryResident = utils_get_symbol_addr( lib_handle, "zeContextMakeMemoryResident", lib_name); + *(void **)&g_ze_ops.zeContextEvictMemory = + utils_get_symbol_addr(lib_handle, "zeContextEvictMemory", lib_name); *(void **)&g_ze_ops.zeDeviceGetProperties = utils_get_symbol_addr(lib_handle, "zeDeviceGetProperties", lib_name); *(void **)&g_ze_ops.zeMemFreeExt = @@ -230,7 +234,8 @@ static void init_ze_global_state(void) { !g_ze_ops.zeMemGetIpcHandle || !g_ze_ops.zeMemOpenIpcHandle || !g_ze_ops.zeMemCloseIpcHandle || !g_ze_ops.zeContextMakeMemoryResident || - !g_ze_ops.zeDeviceGetProperties || !g_ze_ops.zeMemGetAllocProperties) { + !g_ze_ops.zeContextEvictMemory || !g_ze_ops.zeDeviceGetProperties || + !g_ze_ops.zeMemGetAllocProperties) { // g_ze_ops.zeMemPutIpcHandle can be NULL because it was introduced // starting from Level Zero 1.6 LOG_FATAL("Required Level Zero symbols not found."); @@ -1012,8 +1017,9 @@ static int ze_memory_provider_resident_device_change_helper(uintptr_t key, change_data->source_memory_provider->context, change_data->peer_device, info->props.base, info->props.base_size); } else { - result = ZE_RESULT_SUCCESS; - // TODO: currently not implemented call evict here + result = g_ze_ops.zeContextEvictMemory( + change_data->source_memory_provider->context, + change_data->peer_device, info->props.base, info->props.base_size); } if (result != ZE_RESULT_SUCCESS) { diff --git a/test/common/level_zero_mocks.h b/test/common/level_zero_mocks.h index d5fcb22e9..35a5e2b7a 100644 --- a/test/common/level_zero_mocks.h +++ b/test/common/level_zero_mocks.h @@ -64,6 +64,9 @@ class LevelZeroMock : public LevelZero { MOCK_METHOD4(zeContextMakeMemoryResident, ze_result_t(ze_context_handle_t, ze_device_handle_t, void *, size_t)); + MOCK_METHOD4(zeContextEvictMemory, + ze_result_t(ze_context_handle_t, ze_device_handle_t, void *, + size_t)); MOCK_METHOD2(zeMemFree, ze_result_t(ze_context_handle_t hContext, void *ptr)); diff --git a/test/common/ze_loopback.cpp b/test/common/ze_loopback.cpp index a5a10aa3b..d43219a3c 100644 --- a/test/common/ze_loopback.cpp +++ b/test/common/ze_loopback.cpp @@ -263,6 +263,13 @@ ze_result_t ZE_APICALL zeContextMakeMemoryResident(ze_context_handle_t hContext, size); } +ze_result_t ZE_APICALL zeContextEvictMemory(ze_context_handle_t hContext, + ze_device_handle_t hDevice, + void *ptr, size_t size) { + check_mock_present(); + return level_zero_mock->zeContextEvictMemory(hContext, hDevice, ptr, size); +} + ze_result_t ZE_APICALL zeMemFreeExt(ze_context_handle_t hContext, const ze_memory_free_ext_desc_t *pMemFreeDesc, void *ptr) { diff --git a/test/common/ze_loopback.def b/test/common/ze_loopback.def index 0b13bab8c..2dc73ab31 100644 --- a/test/common/ze_loopback.def +++ b/test/common/ze_loopback.def @@ -24,6 +24,7 @@ EXPORTS zeCommandListAppendMemoryCopy zeCommandListAppendMemoryFill zeContextMakeMemoryResident + zeContextEvictMemory zeMemGetAllocProperties zeMemAllocDevice zeMemAllocHost diff --git a/test/common/ze_loopback.h b/test/common/ze_loopback.h index 2bfa441be..bbf95f141 100644 --- a/test/common/ze_loopback.h +++ b/test/common/ze_loopback.h @@ -27,6 +27,9 @@ class LevelZero { virtual ze_result_t zeContextMakeMemoryResident(ze_context_handle_t, ze_device_handle_t, void *, size_t) = 0; + virtual ze_result_t zeContextEvictMemory(ze_context_handle_t, + ze_device_handle_t, void *, + size_t) = 0; virtual ze_result_t zeMemFree(ze_context_handle_t hContext, void *ptr) = 0; }; diff --git a/test/common/ze_loopback.map b/test/common/ze_loopback.map index 08782782f..35c6ac909 100644 --- a/test/common/ze_loopback.map +++ b/test/common/ze_loopback.map @@ -23,6 +23,7 @@ zeCommandListAppendMemoryCopy; zeCommandListAppendMemoryFill; zeContextMakeMemoryResident; + zeContextEvictMemory; zeMemGetAllocProperties; zeMemAllocDevice; zeMemAllocHost; diff --git a/test/pools/pool_residency.cpp b/test/pools/pool_residency.cpp index 806faef1a..c1f60f874 100644 --- a/test/pools/pool_residency.cpp +++ b/test/pools/pool_residency.cpp @@ -109,6 +109,32 @@ TEST_F(PoolResidencyTestFixture, umfPoolFree(pool, ptr); } +TEST_F(PoolResidencyTestFixture, + existingAllocationsShouldBeEvictedFromRemovedDevice) { + initializeMemoryPool(l0mock.initializeMemoryProviderWithResidentDevices( + OUR_DEVICE, {DEVICE_2, DEVICE_3})); + + EXPECT_CALL(l0mock, zeMemAllocDevice(CONTEXT, _, _, _, OUR_DEVICE, _)) + .WillOnce( + DoAll(SetArgPointee<5>(POINTER_0), Return(ZE_RESULT_SUCCESS))); + EXPECT_CALL(l0mock, zeContextMakeMemoryResident(CONTEXT, DEVICE_2, _, _)) + .WillOnce(Return(ZE_RESULT_SUCCESS)); + EXPECT_CALL(l0mock, zeContextMakeMemoryResident(CONTEXT, DEVICE_3, _, _)) + .WillOnce(Return(ZE_RESULT_SUCCESS)); + + void *ptr = umfPoolMalloc(pool, 123); + EXPECT_EQ(ptr, POINTER_0); + + EXPECT_CALL(l0mock, zeContextEvictMemory(CONTEXT, DEVICE_2, _, _)) + .WillOnce(Return(ZE_RESULT_SUCCESS)); + + umf_memory_provider_handle_t provider = nullptr; + EXPECT_EQ(umfPoolGetMemoryProvider(pool, &provider), UMF_RESULT_SUCCESS); + umfLevelZeroMemoryProviderResidentDeviceChange(provider, DEVICE_2, false); + + umfPoolFree(pool, ptr); +} + TEST_F(PoolResidencyTestFixture, allocationShouldNotBeMadeResidentOnRemovedDevice) { initializeMemoryPool(l0mock.initializeMemoryProviderWithResidentDevices(