UnityVulkanInitCallback.cpp 5.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151
#include "pch.h"
#include "UnityVulkanInitCallback.h"

#include <algorithm>
#include <cstring>
#include <string>
#include <vector>
  3. namespace unity
  4. {
  5. namespace webrtc
  6. {
  7. const std::vector<const char*> requestedInstanceExtensions = {
  8. // VK_EXT_DEBUG_REPORT_EXTENSION_NAME,
  9. VK_KHR_GET_PHYSICAL_DEVICE_PROPERTIES_2_EXTENSION_NAME,
  10. VK_KHR_EXTERNAL_MEMORY_CAPABILITIES_EXTENSION_NAME,
  11. VK_KHR_EXTERNAL_SEMAPHORE_CAPABILITIES_EXTENSION_NAME
  12. };
  13. static std::vector<const char*> requestedDeviceExtensions = {
  14. VK_KHR_EXTERNAL_MEMORY_EXTENSION_NAME,
  15. VK_KHR_EXTERNAL_SEMAPHORE_EXTENSION_NAME,
  16. #ifdef UNITY_LINUX
  17. VK_KHR_EXTERNAL_MEMORY_FD_EXTENSION_NAME,
  18. VK_KHR_EXTERNAL_SEMAPHORE_FD_EXTENSION_NAME
  19. #elif UNITY_WIN
  20. VK_KHR_EXTERNAL_MEMORY_WIN32_EXTENSION_NAME,
  21. VK_KHR_EXTERNAL_SEMAPHORE_WIN32_EXTENSION_NAME
  22. #endif
  23. };
  24. static VKAPI_ATTR VkResult VKAPI_CALL Hook_vkCreateDevice(
  25. VkPhysicalDevice physicalDevice,
  26. const VkDeviceCreateInfo* pCreateInfo,
  27. const VkAllocationCallbacks* pAllocator,
  28. VkDevice* pDevice)
  29. {
  30. // copy value
  31. VkDeviceCreateInfo newCreateInfo = *pCreateInfo;
  32. // copy extension name list
  33. std::vector<const char*> enabledExtensions;
  34. enabledExtensions.reserve(pCreateInfo->enabledExtensionCount);
  35. for (uint32_t i = 0; i < pCreateInfo->enabledExtensionCount; i++)
  36. {
  37. enabledExtensions.push_back(newCreateInfo.ppEnabledExtensionNames[i]);
  38. }
  39. // get the union of the two
  40. std::vector<const char*> newExtensions;
  41. std::set_union(
  42. requestedDeviceExtensions.begin(),
  43. requestedDeviceExtensions.end(),
  44. enabledExtensions.begin(),
  45. enabledExtensions.end(),
  46. std::inserter(newExtensions, std::end(newExtensions)));
  47. RTC_LOG(LS_INFO) << "WebRTC plugin intercepts vkCreateDevice.";
  48. for (auto extension : newExtensions)
  49. {
  50. RTC_LOG(LS_INFO) << "[Vulkan init intercept] extensions: name=" << extension;
  51. }
  52. // replace extension name list
  53. newCreateInfo.ppEnabledExtensionNames = newExtensions.data();
  54. newCreateInfo.enabledExtensionCount = static_cast<uint32_t>(newExtensions.size());
  55. VkResult result = vkCreateDevice(physicalDevice, &newCreateInfo, pAllocator, pDevice);
  56. if (result != VK_SUCCESS)
  57. {
  58. RTC_LOG(LS_ERROR) << "vkCreateDevice failed. error:" << result;
  59. return result;
  60. }
  61. if (!LoadDeviceVulkanFunction(*pDevice))
  62. return VK_ERROR_INITIALIZATION_FAILED;
  63. return result;
  64. }
  65. static VKAPI_ATTR VkResult VKAPI_CALL Hook_vkCreateInstance(
  66. const VkInstanceCreateInfo* pCreateInfo, const VkAllocationCallbacks* pAllocator, VkInstance* pInstance)
  67. {
  68. if (!LoadGlobalVulkanFunction())
  69. return VK_ERROR_INITIALIZATION_FAILED;
  70. // copy value
  71. VkInstanceCreateInfo newCreateInfo = *pCreateInfo;
  72. // copy extension name list
  73. std::vector<const char*> enabledExtensions;
  74. enabledExtensions.reserve(pCreateInfo->enabledExtensionCount);
  75. for (uint32_t i = 0; i < pCreateInfo->enabledExtensionCount; i++)
  76. {
  77. enabledExtensions.push_back(newCreateInfo.ppEnabledExtensionNames[i]);
  78. }
  79. // get the union of the two
  80. std::vector<const char*> newExtensions;
  81. std::set_union(
  82. requestedInstanceExtensions.begin(),
  83. requestedInstanceExtensions.end(),
  84. enabledExtensions.begin(),
  85. enabledExtensions.end(),
  86. std::inserter(newExtensions, std::end(newExtensions)));
  87. RTC_LOG(LS_INFO) << "WebRTC plugin intercepts vkCreateInstance.";
  88. for (auto extension : newExtensions)
  89. {
  90. RTC_LOG(LS_INFO) << "[Vulkan init intercept] extensions: name=" << extension;
  91. }
  92. // replace extension name list
  93. newCreateInfo.ppEnabledExtensionNames = newExtensions.data();
  94. newCreateInfo.enabledExtensionCount = static_cast<uint32_t>(newExtensions.size());
  95. VkResult result = vkCreateInstance(&newCreateInfo, pAllocator, pInstance);
  96. if (result != VK_SUCCESS)
  97. {
  98. RTC_LOG(LS_ERROR) << "vkCreateInstance failed. error:" << result;
  99. return result;
  100. }
  101. if (!LoadInstanceVulkanFunction(*pInstance))
  102. return VK_ERROR_INITIALIZATION_FAILED;
  103. return result;
  104. }
  105. static VKAPI_ATTR PFN_vkVoidFunction VKAPI_CALL
  106. Hook_vkGetInstanceProcAddr(VkInstance instance, const char* funcName)
  107. {
  108. if (!funcName)
  109. return nullptr;
  110. std::string strFuncName = funcName;
  111. if (strFuncName == "vkCreateInstance")
  112. {
  113. return reinterpret_cast<PFN_vkVoidFunction>(&Hook_vkCreateInstance);
  114. }
  115. if (strFuncName == "vkCreateDevice")
  116. {
  117. return reinterpret_cast<PFN_vkVoidFunction>(&Hook_vkCreateDevice);
  118. }
  119. return vkGetInstanceProcAddr(instance, funcName);
  120. }
  121. PFN_vkGetInstanceProcAddr
  122. InterceptVulkanInitialization(PFN_vkGetInstanceProcAddr getInstanceProcAddr, void* userdata)
  123. {
  124. vkGetInstanceProcAddr = getInstanceProcAddr;
  125. return Hook_vkGetInstanceProcAddr;
  126. }
  127. }
  128. }