Context.cpp 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427
  1. #include "pch.h"
  2. #include <api/create_peerconnection_factory.h>
  3. #include <api/task_queue/default_task_queue_factory.h>
  4. #include <rtc_base/ssl_adapter.h>
  5. #include <rtc_base/strings/json.h>
  6. #include "AudioTrackSinkAdapter.h"
  7. #include "Context.h"
  8. #include "GraphicsDevice/GraphicsUtility.h"
  9. #include "GraphicsDevice/IGraphicsDevice.h"
  10. #include "MediaStreamObserver.h"
  11. #include "SetSessionDescriptionObserver.h"
  12. #include "UnityAudioDecoderFactory.h"
  13. #include "UnityAudioEncoderFactory.h"
  14. #include "UnityAudioTrackSource.h"
  15. #include "UnityVideoDecoderFactory.h"
  16. #include "UnityVideoEncoderFactory.h"
  17. #include "UnityVideoTrackSource.h"
  18. #include "WebRTCPlugin.h"
  19. #if CUDA_PLATFORM
  20. #include "Logger.h"
  21. simplelogger::Logger* logger = simplelogger::LoggerFactory::CreateConsoleLogger();
  22. #endif
  23. using namespace ::webrtc;
  24. namespace unity
  25. {
  26. namespace webrtc
  27. {
  28. std::unique_ptr<ContextManager> ContextManager::s_instance;
  29. ContextManager* ContextManager::GetInstance()
  30. {
  31. if (s_instance == nullptr)
  32. {
  33. s_instance = std::make_unique<ContextManager>();
  34. }
  35. return s_instance.get();
  36. }
  37. Context* ContextManager::GetContext(int uid) const
  38. {
  39. auto it = s_instance->m_contexts.find(uid);
  40. if (it != s_instance->m_contexts.end())
  41. {
  42. return it->second.get();
  43. }
  44. return nullptr;
  45. }
  46. Context* ContextManager::CreateContext(int uid, ContextDependencies& dependencies)
  47. {
  48. auto it = s_instance->m_contexts.find(uid);
  49. if (it != s_instance->m_contexts.end())
  50. {
  51. DebugLog("Using already created context with ID %d", uid);
  52. return nullptr;
  53. }
  54. s_instance->m_contexts[uid] = std::make_unique<Context>(dependencies);
  55. return s_instance->m_contexts[uid].get();
  56. }
  57. void ContextManager::SetCurContext(Context* context) { curContext = context; }
  58. bool ContextManager::Exists(Context* context)
  59. {
  60. for (auto it = s_instance->m_contexts.begin(); it != s_instance->m_contexts.end(); ++it)
  61. {
  62. if (it->second.get() == context)
  63. return true;
  64. }
  65. return false;
  66. }
  67. void ContextManager::DestroyContext(int uid)
  68. {
  69. auto it = s_instance->m_contexts.find(uid);
  70. if (it != s_instance->m_contexts.end())
  71. {
  72. s_instance->m_contexts.erase(it);
  73. }
  74. }
  75. ContextManager::~ContextManager()
  76. {
  77. if (m_contexts.size())
  78. {
  79. DebugWarning("%lu remaining context(s) registered", m_contexts.size());
  80. }
  81. m_contexts.clear();
  82. }
  83. bool Convert(const std::string& str, PeerConnectionInterface::RTCConfiguration& config)
  84. {
  85. config = PeerConnectionInterface::RTCConfiguration {};
  86. Json::CharReaderBuilder builder;
  87. const std::unique_ptr<Json::CharReader> reader(builder.newCharReader());
  88. Json::Value configJson;
  89. Json::String err;
  90. auto ok = reader->parse(str.c_str(), str.c_str() + static_cast<int>(str.length()), &configJson, &err);
  91. if (!ok)
  92. {
  93. // json parse failed.
  94. return false;
  95. }
  96. Json::Value iceServersJson = configJson["iceServers"];
  97. if (!iceServersJson)
  98. return false;
  99. for (auto iceServerJson : iceServersJson)
  100. {
  101. webrtc::PeerConnectionInterface::IceServer iceServer;
  102. for (auto url : iceServerJson["urls"])
  103. {
  104. iceServer.urls.push_back(url.asString());
  105. }
  106. if (!iceServerJson["username"].isNull())
  107. {
  108. iceServer.username = iceServerJson["username"].asString();
  109. }
  110. if (!iceServerJson["credential"].isNull())
  111. {
  112. iceServer.password = iceServerJson["credential"].asString();
  113. }
  114. config.servers.push_back(iceServer);
  115. }
  116. Json::Value iceTransportPolicy = configJson["iceTransportPolicy"];
  117. if (iceTransportPolicy["hasValue"].asBool())
  118. {
  119. config.type = static_cast<PeerConnectionInterface::IceTransportsType>(iceTransportPolicy["value"].asInt());
  120. }
  121. Json::Value enableDtlsSrtp = configJson["enableDtlsSrtp"];
  122. if (enableDtlsSrtp["hasValue"].asBool())
  123. {
  124. config.enable_dtls_srtp = enableDtlsSrtp["value"].asBool();
  125. }
  126. Json::Value iceCandidatePoolSize = configJson["iceCandidatePoolSize"];
  127. if (iceCandidatePoolSize["hasValue"].asBool())
  128. {
  129. config.ice_candidate_pool_size = iceCandidatePoolSize["value"].asInt();
  130. }
  131. Json::Value bundlePolicy = configJson["bundlePolicy"];
  132. if (bundlePolicy["hasValue"].asBool())
  133. {
  134. config.bundle_policy = static_cast<PeerConnectionInterface::BundlePolicy>(bundlePolicy["value"].asInt());
  135. }
  136. config.sdp_semantics = webrtc::SdpSemantics::kUnifiedPlan;
  137. config.enable_implicit_rollback = true;
  138. return true;
  139. }
  140. Context::Context(ContextDependencies& dependencies)
  141. : m_workerThread(rtc::Thread::CreateWithSocketServer())
  142. , m_signalingThread(rtc::Thread::CreateWithSocketServer())
  143. , m_taskQueueFactory(CreateDefaultTaskQueueFactory())
  144. {
  145. m_workerThread->Start();
  146. m_signalingThread->Start();
  147. rtc::InitializeSSL();
  148. m_audioDevice = m_workerThread->Invoke<rtc::scoped_refptr<DummyAudioDevice>>(
  149. RTC_FROM_HERE, [&]() { return new rtc::RefCountedObject<DummyAudioDevice>(m_taskQueueFactory.get()); });
  150. std::unique_ptr<webrtc::VideoEncoderFactory> videoEncoderFactory =
  151. std::make_unique<UnityVideoEncoderFactory>(dependencies.device, dependencies.profiler);
  152. std::unique_ptr<webrtc::VideoDecoderFactory> videoDecoderFactory =
  153. std::make_unique<UnityVideoDecoderFactory>(dependencies.device, dependencies.profiler);
  154. rtc::scoped_refptr<AudioEncoderFactory> audioEncoderFactory = CreateAudioEncoderFactory();
  155. rtc::scoped_refptr<AudioDecoderFactory> audioDecoderFactory = CreateAudioDecoderFactory();
  156. m_peerConnectionFactory = CreatePeerConnectionFactory(
  157. m_workerThread.get(),
  158. m_workerThread.get(),
  159. m_signalingThread.get(),
  160. m_audioDevice,
  161. audioEncoderFactory,
  162. audioDecoderFactory,
  163. std::move(videoEncoderFactory),
  164. std::move(videoDecoderFactory),
  165. nullptr,
  166. nullptr);
  167. }
  168. Context::~Context()
  169. {
  170. {
  171. std::lock_guard<std::mutex> lock(mutex);
  172. m_peerConnectionFactory = nullptr;
  173. m_workerThread->Invoke<void>(RTC_FROM_HERE, [this]() { m_audioDevice = nullptr; });
  174. m_mapClients.clear();
  175. // check count of refptr to avoid to forget disposing
  176. RTC_DCHECK_EQ(m_mapRefPtr.size(), 0);
  177. m_mapRefPtr.clear();
  178. m_mapMediaStreamObserver.clear();
  179. m_mapDataChannels.clear();
  180. m_mapVideoRenderer.clear();
  181. m_workerThread->Quit();
  182. m_workerThread.reset();
  183. m_signalingThread->Quit();
  184. m_signalingThread.reset();
  185. }
  186. }
  187. webrtc::MediaStreamInterface* Context::CreateMediaStream(const std::string& streamId)
  188. {
  189. rtc::scoped_refptr<webrtc::MediaStreamInterface> stream =
  190. m_peerConnectionFactory->CreateLocalMediaStream(streamId);
  191. AddRefPtr(stream);
  192. return stream;
  193. }
  194. void Context::RegisterMediaStreamObserver(webrtc::MediaStreamInterface* stream)
  195. {
  196. m_mapMediaStreamObserver[stream] = std::make_unique<MediaStreamObserver>(stream, this);
  197. }
  198. void Context::UnRegisterMediaStreamObserver(webrtc::MediaStreamInterface* stream)
  199. {
  200. m_mapMediaStreamObserver.erase(stream);
  201. }
  202. MediaStreamObserver* Context::GetObserver(const webrtc::MediaStreamInterface* stream)
  203. {
  204. return m_mapMediaStreamObserver[stream].get();
  205. }
  206. VideoTrackSourceInterface* Context::CreateVideoSource()
  207. {
  208. const rtc::scoped_refptr<UnityVideoTrackSource> source =
  209. new rtc::RefCountedObject<UnityVideoTrackSource>(false, absl::nullopt, m_taskQueueFactory.get());
  210. AddRefPtr(source);
  211. return source;
  212. }
  213. webrtc::VideoTrackInterface* Context::CreateVideoTrack(const std::string& label, VideoTrackSourceInterface* source)
  214. {
  215. const rtc::scoped_refptr<VideoTrackInterface> track = m_peerConnectionFactory->CreateVideoTrack(label, source);
  216. AddRefPtr(track);
  217. return track;
  218. }
  219. void Context::StopMediaStreamTrack(webrtc::MediaStreamTrackInterface* track)
  220. {
  221. // todo:(kazuki)
  222. }
  223. webrtc::AudioSourceInterface* Context::CreateAudioSource()
  224. {
  225. // avoid optimization specially for voice
  226. cricket::AudioOptions audioOptions;
  227. audioOptions.auto_gain_control = false;
  228. audioOptions.noise_suppression = false;
  229. audioOptions.highpass_filter = false;
  230. const rtc::scoped_refptr<UnityAudioTrackSource> source = UnityAudioTrackSource::Create(audioOptions);
  231. AddRefPtr(source);
  232. return source;
  233. }
  234. AudioTrackInterface* Context::CreateAudioTrack(const std::string& label, webrtc::AudioSourceInterface* source)
  235. {
  236. const rtc::scoped_refptr<AudioTrackInterface> track = m_peerConnectionFactory->CreateAudioTrack(label, source);
  237. AddRefPtr(track);
  238. return track;
  239. }
  240. AudioTrackSinkAdapter* Context::CreateAudioTrackSinkAdapter()
  241. {
  242. auto sink = std::make_unique<AudioTrackSinkAdapter>();
  243. AudioTrackSinkAdapter* ptr = sink.get();
  244. m_mapAudioTrackAndSink.emplace(ptr, std::move(sink));
  245. return ptr;
  246. }
  247. void Context::DeleteAudioTrackSinkAdapter(AudioTrackSinkAdapter* sink) { m_mapAudioTrackAndSink.erase(sink); }
  248. void Context::AddStatsReport(const rtc::scoped_refptr<const webrtc::RTCStatsReport>& report)
  249. {
  250. std::lock_guard<std::mutex> lock(mutexStatsReport);
  251. m_listStatsReport.push_back(report);
  252. }
  253. const RTCStats** Context::GetStatsList(const RTCStatsReport* report, size_t* length, uint32_t** types)
  254. {
  255. std::lock_guard<std::mutex> lock(mutexStatsReport);
  256. auto result = std::find_if(
  257. m_listStatsReport.begin(),
  258. m_listStatsReport.end(),
  259. [report](rtc::scoped_refptr<const webrtc::RTCStatsReport> it) { return it.get() == report; });
  260. if (result == m_listStatsReport.end())
  261. {
  262. RTC_LOG(LS_INFO) << "Calling GetStatsList is failed. The reference of RTCStatsReport is not found.";
  263. return nullptr;
  264. }
  265. const size_t size = report->size();
  266. *length = size;
  267. *types = static_cast<uint32_t*>(CoTaskMemAlloc(sizeof(uint32_t) * size));
  268. void* buf = CoTaskMemAlloc(sizeof(RTCStats*) * size);
  269. const RTCStats** ret = static_cast<const RTCStats**>(buf);
  270. if (size == 0)
  271. {
  272. return ret;
  273. }
  274. int i = 0;
  275. for (const auto& stats : *report)
  276. {
  277. ret[i] = &stats;
  278. (*types)[i] = statsTypes.at(stats.type());
  279. i++;
  280. }
  281. return ret;
  282. }
  283. void Context::DeleteStatsReport(const webrtc::RTCStatsReport* report)
  284. {
  285. std::lock_guard<std::mutex> lock(mutexStatsReport);
  286. auto result = std::find_if(
  287. m_listStatsReport.begin(),
  288. m_listStatsReport.end(),
  289. [report](rtc::scoped_refptr<const webrtc::RTCStatsReport> it) { return it.get() == report; });
  290. if (result == m_listStatsReport.end())
  291. {
  292. RTC_LOG(LS_INFO) << "Calling DeleteStatsReport is failed. The reference of RTCStatsReport is not found.";
  293. return;
  294. }
  295. m_listStatsReport.erase(result);
  296. }
  297. DataChannelInterface*
  298. Context::CreateDataChannel(PeerConnectionObject* obj, const char* label, const DataChannelInit& options)
  299. {
  300. const rtc::scoped_refptr<DataChannelInterface> channel = obj->connection->CreateDataChannel(label, &options);
  301. if (channel == nullptr)
  302. return nullptr;
  303. AddDataChannel(channel, *obj);
  304. return channel;
  305. }
  306. void Context::AddDataChannel(DataChannelInterface* channel, PeerConnectionObject& pc)
  307. {
  308. auto dataChannelObj = std::make_unique<DataChannelObject>(channel, pc);
  309. m_mapDataChannels[channel] = std::move(dataChannelObj);
  310. }
  311. DataChannelObject* Context::GetDataChannelObject(const DataChannelInterface* channel)
  312. {
  313. return m_mapDataChannels[channel].get();
  314. }
  315. void Context::DeleteDataChannel(DataChannelInterface* channel)
  316. {
  317. if (m_mapDataChannels.count(channel) > 0)
  318. {
  319. m_mapDataChannels.erase(channel);
  320. }
  321. }
  322. PeerConnectionObject* Context::CreatePeerConnection(const webrtc::PeerConnectionInterface::RTCConfiguration& config)
  323. {
  324. std::unique_ptr<PeerConnectionObject> obj = std::make_unique<PeerConnectionObject>(*this);
  325. PeerConnectionDependencies dependencies(obj.get());
  326. auto connection = m_peerConnectionFactory->CreatePeerConnectionOrError(config, std::move(dependencies));
  327. if (!connection.ok())
  328. {
  329. RTC_LOG(LS_ERROR) << connection.error().message();
  330. return nullptr;
  331. }
  332. obj->connection = connection.MoveValue();
  333. PeerConnectionObject* ptr = obj.get();
  334. m_mapClients[ptr] = std::move(obj);
  335. return ptr;
  336. }
  337. void Context::DeletePeerConnection(PeerConnectionObject* obj) { m_mapClients.erase(obj); }
  338. uint32_t Context::s_rendererId = 0;
  339. uint32_t Context::GenerateRendererId() { return s_rendererId++; }
  340. UnityVideoRenderer* Context::CreateVideoRenderer(DelegateVideoFrameResize callback, bool needFlipVertical)
  341. {
  342. auto rendererId = GenerateRendererId();
  343. auto renderer = std::make_shared<UnityVideoRenderer>(rendererId, callback, needFlipVertical);
  344. m_mapVideoRenderer[rendererId] = renderer;
  345. return m_mapVideoRenderer[rendererId].get();
  346. }
  347. std::shared_ptr<UnityVideoRenderer> Context::GetVideoRenderer(uint32_t id) { return m_mapVideoRenderer[id]; }
  348. void Context::DeleteVideoRenderer(UnityVideoRenderer* renderer)
  349. {
  350. m_mapVideoRenderer.erase(renderer->GetId());
  351. renderer = nullptr;
  352. }
  353. void Context::GetRtpSenderCapabilities(cricket::MediaType kind, RtpCapabilities* capabilities) const
  354. {
  355. *capabilities = m_peerConnectionFactory->GetRtpSenderCapabilities(kind);
  356. }
  357. void Context::GetRtpReceiverCapabilities(cricket::MediaType kind, RtpCapabilities* capabilities) const
  358. {
  359. *capabilities = m_peerConnectionFactory->GetRtpReceiverCapabilities(kind);
  360. }
  361. } // end namespace webrtc
  362. } // end namespace unity