// Copyright 2016 The Chromium Authors. All rights reserved. // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. //#include "DispatcherBase.h" //#include "Parser.h" {% for namespace in config.protocol.namespace %} namespace {{namespace}} { {% endfor %} // static DispatchResponse DispatchResponse::OK() { DispatchResponse result; result.m_status = kSuccess; result.m_errorCode = kParseError; return result; } // static DispatchResponse DispatchResponse::Error(const String& error) { DispatchResponse result; result.m_status = kError; result.m_errorCode = kServerError; result.m_errorMessage = error; return result; } // static DispatchResponse DispatchResponse::InternalError() { DispatchResponse result; result.m_status = kError; result.m_errorCode = kInternalError; result.m_errorMessage = "Internal error"; return result; } // static DispatchResponse DispatchResponse::InvalidParams(const String& error) { DispatchResponse result; result.m_status = kError; result.m_errorCode = kInvalidParams; result.m_errorMessage = error; return result; } // static DispatchResponse DispatchResponse::FallThrough() { DispatchResponse result; result.m_status = kFallThrough; result.m_errorCode = kParseError; return result; } // static const char DispatcherBase::kInvalidParamsString[] = "Invalid parameters"; DispatcherBase::WeakPtr::WeakPtr(DispatcherBase* dispatcher) : m_dispatcher(dispatcher) { } DispatcherBase::WeakPtr::~WeakPtr() { if (m_dispatcher) m_dispatcher->m_weakPtrs.erase(this); } DispatcherBase::Callback::Callback(std::unique_ptr backendImpl, int callId, int callbackId) : m_backendImpl(std::move(backendImpl)) , m_callId(callId) , m_callbackId(callbackId) { } DispatcherBase::Callback::~Callback() = default; void DispatcherBase::Callback::dispose() { m_backendImpl = nullptr; } void DispatcherBase::Callback::sendIfActive(std::unique_ptr partialMessage, const DispatchResponse& response) { if (!m_backendImpl || !m_backendImpl->get()) return; m_backendImpl->get()->sendResponse(m_callId, response, std::move(partialMessage)); m_backendImpl = nullptr; } void DispatcherBase::Callback::fallThroughIfActive() { if (!m_backendImpl || !m_backendImpl->get()) return; m_backendImpl->get()->markFallThrough(m_callbackId); m_backendImpl = nullptr; } DispatcherBase::DispatcherBase(FrontendChannel* frontendChannel) : m_frontendChannel(frontendChannel) , m_lastCallbackId(0) , m_lastCallbackFallThrough(false) { } DispatcherBase::~DispatcherBase() { clearFrontend(); } int DispatcherBase::nextCallbackId() { m_lastCallbackFallThrough = false; return ++m_lastCallbackId; } void DispatcherBase::markFallThrough(int callbackId) { DCHECK(callbackId == m_lastCallbackId); m_lastCallbackFallThrough = true; } // static bool DispatcherBase::getCommandName(const String& message, String* result) { std::unique_ptr value = StringUtil::parseJSON(message); if (!value) return false; protocol::DictionaryValue* object = DictionaryValue::cast(value.get()); if (!object) return false; if (!object->getString("method", result)) return false; return true; } void DispatcherBase::sendResponse(int callId, const DispatchResponse& response, std::unique_ptr result) { if (!m_frontendChannel) return; if (response.status() == DispatchResponse::kError) { reportProtocolError(callId, response.errorCode(), response.errorMessage(), nullptr); return; } m_frontendChannel->sendProtocolResponse(callId, InternalResponse::createResponse(callId, std::move(result))); } void DispatcherBase::sendResponse(int callId, const DispatchResponse& response) { sendResponse(callId, response, DictionaryValue::create()); } namespace { class ProtocolError : public Serializable { public: static std::unique_ptr createErrorResponse(int callId, DispatchResponse::ErrorCode code, const String& errorMessage, ErrorSupport* errors) { std::unique_ptr protocolError(new ProtocolError(code, errorMessage)); protocolError->m_callId = callId; protocolError->m_hasCallId = true; if (errors && errors->hasErrors()) protocolError->m_data = errors->errors(); return protocolError; } static std::unique_ptr createErrorNotification(DispatchResponse::ErrorCode code, const String& errorMessage) { return std::unique_ptr(new ProtocolError(code, errorMessage)); } String serialize() override { std::unique_ptr error = DictionaryValue::create(); error->setInteger("code", m_code); error->setString("message", m_errorMessage); if (m_data.length()) error->setString("data", m_data); std::unique_ptr message = DictionaryValue::create(); message->setObject("error", std::move(error)); if (m_hasCallId) message->setInteger("id", m_callId); return message->serialize(); } ~ProtocolError() override {} private: ProtocolError(DispatchResponse::ErrorCode code, const String& errorMessage) : m_code(code) , m_errorMessage(errorMessage) { } DispatchResponse::ErrorCode m_code; String m_errorMessage; String m_data; int m_callId = 0; bool m_hasCallId = false; }; } // namespace static void reportProtocolErrorTo(FrontendChannel* frontendChannel, int callId, DispatchResponse::ErrorCode code, const String& errorMessage, ErrorSupport* errors) { if (frontendChannel) frontendChannel->sendProtocolResponse(callId, ProtocolError::createErrorResponse(callId, code, errorMessage, errors)); } static void reportProtocolErrorTo(FrontendChannel* frontendChannel, DispatchResponse::ErrorCode code, const String& errorMessage) { if (frontendChannel) frontendChannel->sendProtocolNotification(ProtocolError::createErrorNotification(code, errorMessage)); } void DispatcherBase::reportProtocolError(int callId, DispatchResponse::ErrorCode code, const String& errorMessage, ErrorSupport* errors) { reportProtocolErrorTo(m_frontendChannel, callId, code, errorMessage, errors); } void DispatcherBase::clearFrontend() { m_frontendChannel = nullptr; for (auto& weak : m_weakPtrs) weak->dispose(); m_weakPtrs.clear(); } std::unique_ptr DispatcherBase::weakPtr() { std::unique_ptr weak(new DispatcherBase::WeakPtr(this)); m_weakPtrs.insert(weak.get()); return weak; } UberDispatcher::UberDispatcher(FrontendChannel* frontendChannel) : m_frontendChannel(frontendChannel) , m_fallThroughForNotFound(false) { } void UberDispatcher::setFallThroughForNotFound(bool fallThroughForNotFound) { m_fallThroughForNotFound = fallThroughForNotFound; } void UberDispatcher::registerBackend(const String& name, std::unique_ptr dispatcher) { m_dispatchers[name] = std::move(dispatcher); } DispatchResponse::Status UberDispatcher::dispatch(std::unique_ptr parsedMessage) { if (!parsedMessage) { reportProtocolErrorTo(m_frontendChannel, DispatchResponse::kParseError, "Message must be a valid JSON"); return DispatchResponse::kError; } std::unique_ptr messageObject = DictionaryValue::cast(std::move(parsedMessage)); if (!messageObject) { reportProtocolErrorTo(m_frontendChannel, DispatchResponse::kInvalidRequest, "Message must be an object"); return DispatchResponse::kError; } int callId = 0; protocol::Value* callIdValue = messageObject->get("id"); bool success = callIdValue && callIdValue->asInteger(&callId); if (!success) { reportProtocolErrorTo(m_frontendChannel, DispatchResponse::kInvalidRequest, "Message must have integer 'id' porperty"); return DispatchResponse::kError; } protocol::Value* methodValue = messageObject->get("method"); String method; success = methodValue && methodValue->asString(&method); if (!success) { reportProtocolErrorTo(m_frontendChannel, callId, DispatchResponse::kInvalidRequest, "Message must have string 'method' porperty", nullptr); return DispatchResponse::kError; } size_t dotIndex = StringUtil::find(method, "."); if (dotIndex == StringUtil::kNotFound) { if (m_fallThroughForNotFound) return DispatchResponse::kFallThrough; reportProtocolErrorTo(m_frontendChannel, callId, DispatchResponse::kMethodNotFound, "'" + method + "' wasn't found", nullptr); return DispatchResponse::kError; } String domain = StringUtil::substring(method, 0, dotIndex); auto it = m_dispatchers.find(domain); if (it == m_dispatchers.end()) { if (m_fallThroughForNotFound) return DispatchResponse::kFallThrough; reportProtocolErrorTo(m_frontendChannel, callId, DispatchResponse::kMethodNotFound, "'" + method + "' wasn't found", nullptr); return DispatchResponse::kError; } return it->second->dispatch(callId, method, std::move(messageObject)); } UberDispatcher::~UberDispatcher() = default; // static std::unique_ptr InternalResponse::createResponse(int callId, std::unique_ptr params) { return std::unique_ptr(new InternalResponse(callId, String(), std::move(params))); } // static std::unique_ptr InternalResponse::createNotification(const String& notification, std::unique_ptr params) { return std::unique_ptr(new InternalResponse(0, notification, std::move(params))); } String InternalResponse::serialize() { std::unique_ptr result = DictionaryValue::create(); std::unique_ptr params(m_params ? std::move(m_params) : DictionaryValue::create()); if (m_notification.length()) { result->setString("method", m_notification); result->setValue("params", SerializedValue::create(params->serialize())); } else { result->setInteger("id", m_callId); result->setValue("result", SerializedValue::create(params->serialize())); } return result->serialize(); } InternalResponse::InternalResponse(int callId, const String& notification, std::unique_ptr params) : m_callId(callId) , m_notification(notification) , m_params(params ? std::move(params) : nullptr) { } {% for namespace in config.protocol.namespace %} } // namespace {{namespace}} {% endfor %}