// Copyright 2016 The Chromium Authors. All rights reserved. // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. //#include "DispatcherBase.h" //#include "Parser.h" {% for namespace in config.protocol.namespace %} namespace {{namespace}} { {% endfor %} // static DispatchResponse DispatchResponse::OK() { DispatchResponse result; result.m_status = kSuccess; result.m_errorCode = kParseError; return result; } // static DispatchResponse DispatchResponse::Error(const String& error) { DispatchResponse result; result.m_status = kError; result.m_errorCode = kServerError; result.m_errorMessage = error; return result; } // static DispatchResponse DispatchResponse::InternalError() { DispatchResponse result; result.m_status = kError; result.m_errorCode = kInternalError; result.m_errorMessage = "Internal error"; return result; } // static DispatchResponse DispatchResponse::InvalidParams(const String& error) { DispatchResponse result; result.m_status = kError; result.m_errorCode = kInvalidParams; result.m_errorMessage = error; return result; } // static DispatchResponse DispatchResponse::FallThrough() { DispatchResponse result; result.m_status = kFallThrough; result.m_errorCode = kParseError; return result; } // static const char DispatcherBase::kInvalidParamsString[] = "Invalid parameters"; DispatcherBase::WeakPtr::WeakPtr(DispatcherBase* dispatcher) : m_dispatcher(dispatcher) { } DispatcherBase::WeakPtr::~WeakPtr() { if (m_dispatcher) m_dispatcher->m_weakPtrs.erase(this); } DispatcherBase::Callback::Callback(std::unique_ptr backendImpl, int callId, int callbackId) : m_backendImpl(std::move(backendImpl)) , m_callId(callId) , m_callbackId(callbackId) { } DispatcherBase::Callback::~Callback() = default; void DispatcherBase::Callback::dispose() { m_backendImpl = nullptr; } void DispatcherBase::Callback::sendIfActive(std::unique_ptr partialMessage, const DispatchResponse& response) { if (!m_backendImpl || !m_backendImpl->get()) return; m_backendImpl->get()->sendResponse(m_callId, response, std::move(partialMessage)); m_backendImpl = nullptr; } void DispatcherBase::Callback::fallThroughIfActive() { if (!m_backendImpl || !m_backendImpl->get()) return; m_backendImpl->get()->markFallThrough(m_callbackId); m_backendImpl = nullptr; } DispatcherBase::DispatcherBase(FrontendChannel* frontendChannel) : m_frontendChannel(frontendChannel) , m_lastCallbackId(0) , m_lastCallbackFallThrough(false) { } DispatcherBase::~DispatcherBase() { clearFrontend(); } int DispatcherBase::nextCallbackId() { m_lastCallbackFallThrough = false; return ++m_lastCallbackId; } void DispatcherBase::markFallThrough(int callbackId) { DCHECK(callbackId == m_lastCallbackId); m_lastCallbackFallThrough = true; } // static bool DispatcherBase::getCommandName(const String& message, String* result) { std::unique_ptr value = StringUtil::parseJSON(message); if (!value) return false; protocol::DictionaryValue* object = DictionaryValue::cast(value.get()); if (!object) return false; if (!object->getString("method", result)) return false; return true; } void DispatcherBase::sendResponse(int callId, const DispatchResponse& response, std::unique_ptr result) { if (response.status() == DispatchResponse::kError) { reportProtocolError(callId, response.errorCode(), response.errorMessage(), nullptr); return; } std::unique_ptr responseMessage = DictionaryValue::create(); responseMessage->setInteger("id", callId); responseMessage->setObject("result", std::move(result)); if (m_frontendChannel) m_frontendChannel->sendProtocolResponse(callId, responseMessage->toJSONString()); } void DispatcherBase::sendResponse(int callId, const DispatchResponse& response) { sendResponse(callId, response, DictionaryValue::create()); } static void reportProtocolErrorTo(FrontendChannel* frontendChannel, int callId, DispatchResponse::ErrorCode code, const String& errorMessage, ErrorSupport* errors) { if (!frontendChannel) return; std::unique_ptr error = DictionaryValue::create(); error->setInteger("code", code); error->setString("message", errorMessage); if (errors && errors->hasErrors()) error->setString("data", errors->errors()); std::unique_ptr message = DictionaryValue::create(); message->setObject("error", std::move(error)); message->setInteger("id", callId); frontendChannel->sendProtocolResponse(callId, message->toJSONString()); } static void reportProtocolErrorTo(FrontendChannel* frontendChannel, DispatchResponse::ErrorCode code, const String& errorMessage) { if (!frontendChannel) return; std::unique_ptr error = DictionaryValue::create(); error->setInteger("code", code); error->setString("message", errorMessage); std::unique_ptr message = DictionaryValue::create(); message->setObject("error", std::move(error)); frontendChannel->sendProtocolNotification(message->toJSONString()); } void DispatcherBase::reportProtocolError(int callId, DispatchResponse::ErrorCode code, const String& errorMessage, ErrorSupport* errors) { reportProtocolErrorTo(m_frontendChannel, callId, code, errorMessage, errors); } void DispatcherBase::clearFrontend() { m_frontendChannel = nullptr; for (auto& weak : m_weakPtrs) weak->dispose(); m_weakPtrs.clear(); } std::unique_ptr DispatcherBase::weakPtr() { std::unique_ptr weak(new DispatcherBase::WeakPtr(this)); m_weakPtrs.insert(weak.get()); return weak; } UberDispatcher::UberDispatcher(FrontendChannel* frontendChannel) : m_frontendChannel(frontendChannel) , m_fallThroughForNotFound(false) { } void UberDispatcher::setFallThroughForNotFound(bool fallThroughForNotFound) { m_fallThroughForNotFound = fallThroughForNotFound; } void UberDispatcher::registerBackend(const String& name, std::unique_ptr dispatcher) { m_dispatchers[name] = std::move(dispatcher); } DispatchResponse::Status UberDispatcher::dispatch(std::unique_ptr parsedMessage) { if (!parsedMessage) { reportProtocolErrorTo(m_frontendChannel, DispatchResponse::kParseError, "Message must be a valid JSON"); return DispatchResponse::kError; } std::unique_ptr messageObject = DictionaryValue::cast(std::move(parsedMessage)); if (!messageObject) { reportProtocolErrorTo(m_frontendChannel, DispatchResponse::kInvalidRequest, "Message must be an object"); return DispatchResponse::kError; } int callId = 0; protocol::Value* callIdValue = messageObject->get("id"); bool success = callIdValue && callIdValue->asInteger(&callId); if (!success) { reportProtocolErrorTo(m_frontendChannel, DispatchResponse::kInvalidRequest, "Message must have integer 'id' porperty"); return DispatchResponse::kError; } protocol::Value* methodValue = messageObject->get("method"); String method; success = methodValue && methodValue->asString(&method); if (!success) { reportProtocolErrorTo(m_frontendChannel, callId, DispatchResponse::kInvalidRequest, "Message must have string 'method' porperty", nullptr); return DispatchResponse::kError; } size_t dotIndex = method.find("."); if (dotIndex == StringUtil::kNotFound) { if (m_fallThroughForNotFound) return DispatchResponse::kFallThrough; reportProtocolErrorTo(m_frontendChannel, callId, DispatchResponse::kMethodNotFound, "'" + method + "' wasn't found", nullptr); return DispatchResponse::kError; } String domain = StringUtil::substring(method, 0, dotIndex); auto it = m_dispatchers.find(domain); if (it == m_dispatchers.end()) { if (m_fallThroughForNotFound) return DispatchResponse::kFallThrough; reportProtocolErrorTo(m_frontendChannel, callId, DispatchResponse::kMethodNotFound, "'" + method + "' wasn't found", nullptr); return DispatchResponse::kError; } return it->second->dispatch(callId, method, std::move(messageObject)); } UberDispatcher::~UberDispatcher() = default; {% for namespace in config.protocol.namespace %} } // namespace {{namespace}} {% endfor %}