CudaContext.cpp 9.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373
  #include "pch.h"

  #include <stdexcept>

  #include "CudaContext.h"
  #if SUPPORT_VULKAN
  #include "GraphicsDevice/Vulkan/VulkanUtility.h"
  #endif
  6. #if SUPPORT_D3D11
  7. using namespace Microsoft::WRL;
  8. #endif
  9. namespace unity
  10. {
  11. namespace webrtc
  12. {
  // Handle of the CUDA driver library while this translation unit keeps it loaded.
  // Only the Windows path ever stores a non-null value here.
  static void* s_hModule = nullptr;

  // Reports whether the CUDA driver library can be loaded.
  //
  // Windows: loads nvcuda.dll and caches the handle in s_hModule, so later calls
  //          return immediately.
  // Linux:   opens libcuda.so.1 as a probe and closes it again right away; the
  //          handle is not cached, so every call repeats the probe.
  // Other:   nothing is probed and true is returned.
  static bool FindModule()
  {
      if (s_hModule != nullptr)
      {
          return true;
      }

  #if UNITY_WIN
      // The driver dll is loaded on demand rather than linked at start-up.
      const HMODULE loaded = LoadLibrary(TEXT("nvcuda.dll"));
      if (loaded == nullptr)
      {
          RTC_LOG(LS_INFO) << "nvcuda.dll is not found.";
          return false;
      }
      s_hModule = loaded;
  #elif UNITY_LINUX
      void* const probe = dlopen("libcuda.so.1", RTLD_LAZY | RTLD_GLOBAL);
      if (probe == nullptr)
      {
          return false;
      }
      // Release the probe handle immediately: the implib module calls `dlopen`
      // again by itself when a cuda api is invoked on Linux.
      dlclose(probe);
  #endif

      return true;
  }
  38. static CUresult CheckDriverVersion()
  39. {
  40. int driverVersion = 0;
  41. CUresult result = cuDriverGetVersion(&driverVersion);
  42. if (result != CUDA_SUCCESS)
  43. {
  44. return result;
  45. }
  46. if (kRequiredDriverVersion > driverVersion)
  47. {
  48. RTC_LOG(LS_ERROR) << "CUDA driver version is not higher than the required version. " << driverVersion;
  49. return CUDA_ERROR_NO_DEVICE;
  50. }
  51. return CUDA_SUCCESS;
  52. }
  53. CudaContext::CudaContext()
  54. : m_context(nullptr)
  55. {
  56. }
  57. CUresult CudaContext::FindCudaDevice(const uint8_t* uuid, CUdevice* cuDevice)
  58. {
  59. bool found = FindModule();
  60. if (!found)
  61. return CUDA_ERROR_NO_DEVICE;
  62. CUdevice _cuDevice = 0;
  63. CUresult result = CUDA_SUCCESS;
  64. int numDevices = 0;
  65. result = cuDeviceGetCount(&numDevices);
  66. if (result != CUDA_SUCCESS)
  67. {
  68. return result;
  69. }
  70. CUuuid id = {};
  71. // Loop over the available devices and identify the CUdevice corresponding to the physical device in use by
  72. // this Vulkan instance. This is required because there is no other way to match GPUs across API boundaries.
  73. for (int i = 0; i < numDevices; i++)
  74. {
  75. result = cuDeviceGet(&_cuDevice, i);
  76. if (result != CUDA_SUCCESS)
  77. {
  78. return result;
  79. }
  80. result = cuDeviceGetUuid(&id, _cuDevice);
  81. if (result != CUDA_SUCCESS)
  82. {
  83. return result;
  84. }
  85. if (!std::memcmp(static_cast<const void*>(&id), static_cast<const void*>(uuid), sizeof(CUuuid)))
  86. {
  87. if (cuDevice != nullptr)
  88. *cuDevice = _cuDevice;
  89. return CUDA_SUCCESS;
  90. }
  91. }
  92. return CUDA_ERROR_NO_DEVICE;
  93. }
  94. CUresult CudaContext::Init(const VkInstance instance, VkPhysicalDevice physicalDevice)
  95. {
  96. // dll check
  97. bool found = FindModule();
  98. if (!found)
  99. {
  100. return CUDA_ERROR_NOT_FOUND;
  101. }
  102. CUresult result = CheckDriverVersion();
  103. if (result != CUDA_SUCCESS)
  104. {
  105. return result;
  106. }
  107. CUdevice cuDevice = 0;
  108. result = cuInit(0);
  109. if (result != CUDA_SUCCESS)
  110. {
  111. return result;
  112. }
  113. int numDevices = 0;
  114. result = cuDeviceGetCount(&numDevices);
  115. if (result != CUDA_SUCCESS)
  116. {
  117. return result;
  118. }
  119. CUuuid id = {};
  120. std::array<uint8_t, VK_UUID_SIZE> deviceUUID;
  121. if (!VulkanUtility::GetPhysicalDeviceUUIDInto(instance, physicalDevice, &deviceUUID))
  122. {
  123. return CUDA_ERROR_INVALID_DEVICE;
  124. }
  125. // Loop over the available devices and identify the CUdevice corresponding to the physical device in use by
  126. // this Vulkan instance. This is required because there is no other way to match GPUs across API boundaries.
  127. bool foundDevice = false;
  128. for (int i = 0; i < numDevices; i++)
  129. {
  130. cuDeviceGet(&cuDevice, i);
  131. cuDeviceGetUuid(&id, cuDevice);
  132. if (!std::memcmp(
  133. static_cast<const void*>(&id), static_cast<const void*>(deviceUUID.data()), sizeof(CUuuid)))
  134. {
  135. foundDevice = true;
  136. break;
  137. }
  138. }
  139. if (!foundDevice)
  140. {
  141. return CUDA_ERROR_NO_DEVICE;
  142. }
  143. result = cuCtxCreate(&m_context, 0, cuDevice);
  144. return result;
  145. }
  146. //---------------------------------------------------------------------------------------------------------------------
  147. #if defined(SUPPORT_D3D11)
  148. CUresult CudaContext::Init(ID3D11Device* device)
  149. {
  150. bool found = FindModule();
  151. if (!found)
  152. {
  153. return CUDA_ERROR_NOT_FOUND;
  154. }
  155. CUresult result = CheckDriverVersion();
  156. if (result != CUDA_SUCCESS)
  157. {
  158. return result;
  159. }
  160. result = cuInit(0);
  161. if (result != CUDA_SUCCESS)
  162. {
  163. return result;
  164. }
  165. int numDevices = 0;
  166. result = cuDeviceGetCount(&numDevices);
  167. if (result != CUDA_SUCCESS)
  168. {
  169. return result;
  170. }
  171. ComPtr<IDXGIDevice> pDxgiDevice = nullptr;
  172. if (device->QueryInterface(IID_PPV_ARGS(&pDxgiDevice)) != S_OK)
  173. {
  174. return CUDA_ERROR_NO_DEVICE;
  175. }
  176. ComPtr<IDXGIAdapter> pDxgiAdapter = nullptr;
  177. if (pDxgiDevice->GetAdapter(&pDxgiAdapter) != S_OK)
  178. {
  179. return CUDA_ERROR_NO_DEVICE;
  180. }
  181. CUdevice dev;
  182. if (cuD3D11GetDevice(&dev, pDxgiAdapter.Get()) != CUDA_SUCCESS)
  183. {
  184. return CUDA_ERROR_NO_DEVICE;
  185. }
  186. result = cuCtxCreate(&m_context, 0, dev);
  187. return result;
  188. }
  189. #endif
  190. //---------------------------------------------------------------------------------------------------------------------
  191. #if defined(SUPPORT_D3D12)
  192. CUresult CudaContext::Init(ID3D12Device* device)
  193. {
  194. bool found = FindModule();
  195. if (!found)
  196. {
  197. return CUDA_ERROR_NOT_FOUND;
  198. }
  199. CUresult result = CheckDriverVersion();
  200. if (result != CUDA_SUCCESS)
  201. {
  202. return result;
  203. }
  204. result = cuInit(0);
  205. if (result != CUDA_SUCCESS)
  206. {
  207. return result;
  208. }
  209. int numDevices = 0;
  210. result = cuDeviceGetCount(&numDevices);
  211. if (result != CUDA_SUCCESS)
  212. {
  213. return result;
  214. }
  215. LUID luid = device->GetAdapterLuid();
  216. CUdevice cuDevice = 0;
  217. bool deviceFound = false;
  218. for (int32_t deviceIndex = 0; deviceIndex < numDevices; deviceIndex++)
  219. {
  220. result = cuDeviceGet(&cuDevice, deviceIndex);
  221. if (result != CUDA_SUCCESS)
  222. {
  223. return result;
  224. }
  225. char luid_[8];
  226. unsigned int nodeMask;
  227. result = cuDeviceGetLuid(luid_, &nodeMask, cuDevice);
  228. if (result != CUDA_SUCCESS)
  229. {
  230. return result;
  231. }
  232. if (memcmp(&luid.LowPart, luid_, sizeof(luid.LowPart)) == 0 &&
  233. memcmp(&luid.HighPart, luid_ + sizeof(luid.LowPart), sizeof(luid.HighPart)) == 0)
  234. {
  235. deviceFound = true;
  236. break;
  237. }
  238. }
  239. if (!deviceFound)
  240. return CUDA_ERROR_NO_DEVICE;
  241. return cuCtxCreate(&m_context, 0, cuDevice);
  242. }
  243. #endif
  244. //---------------------------------------------------------------------------------------------------------------------
  245. // todo(kazuki):: not supported on windows
  246. #if defined(SUPPORT_OPENGL_UNIFIED) && defined(UNITY_LINUX)
  247. CUresult CudaContext::InitGL()
  248. {
  249. // dll check
  250. bool found = FindModule();
  251. if (!found)
  252. {
  253. return CUDA_ERROR_NOT_FOUND;
  254. }
  255. CUresult result = CheckDriverVersion();
  256. if (result != CUDA_SUCCESS)
  257. {
  258. return result;
  259. }
  260. result = cuInit(0);
  261. if (result != CUDA_SUCCESS)
  262. {
  263. return result;
  264. }
  265. int numDevices;
  266. result = cuDeviceGetCount(&numDevices);
  267. if (CUDA_SUCCESS != result)
  268. {
  269. return result;
  270. }
  271. if (numDevices == 0)
  272. {
  273. return CUDA_ERROR_NO_DEVICE;
  274. }
  275. // TODO:: check GPU capability
  276. int cuDevId = 0;
  277. CUdevice cuDevice = 0;
  278. result = cuDeviceGet(&cuDevice, cuDevId);
  279. if (CUDA_SUCCESS != result)
  280. {
  281. return result;
  282. }
  283. result = cuCtxCreate(&m_context, 0, cuDevice);
  284. if (CUDA_SUCCESS != result)
  285. {
  286. return result;
  287. }
  288. return CUDA_SUCCESS;
  289. }
  290. #endif
  291. //---------------------------------------------------------------------------------------------------------------------
  292. CUcontext CudaContext::GetContext() const
  293. {
  294. RTC_DCHECK(m_context);
  295. CUcontext current;
  296. if (cuCtxGetCurrent(&current) != CUDA_SUCCESS)
  297. {
  298. throw;
  299. }
  300. if (m_context == current)
  301. {
  302. return m_context;
  303. }
  304. if (cuCtxSetCurrent(m_context) != CUDA_SUCCESS)
  305. {
  306. throw;
  307. }
  308. return m_context;
  309. }
  310. void CudaContext::Shutdown()
  311. {
  312. if (m_context)
  313. {
  314. cuCtxDestroy(m_context);
  315. m_context = nullptr;
  316. }
  317. if (s_hModule)
  318. {
  319. #if UNITY_WIN
  320. FreeLibrary((HMODULE)s_hModule);
  321. #elif UNITY_LINUX
  322. dlclose(s_hModule);
  323. #endif
  324. s_hModule = nullptr;
  325. }
  326. }
  327. } // end namespace webrtc
  328. } // end namespace unity