diff --git a/hphp/runtime/base/server/satellite_server.cpp b/hphp/runtime/base/server/satellite_server.cpp index 821bed607..55dbea5da 100644 --- a/hphp/runtime/base/server/satellite_server.cpp +++ b/hphp/runtime/base/server/satellite_server.cpp @@ -70,46 +70,33 @@ SatelliteServerInfo::SatelliteServerInfo(Hdf hdf) { } } +bool SatelliteServerInfo::checkMainURL(const std::string& path) { + String url(path.c_str(), path.size(), AttachLiteral); + for (std::set::const_iterator iter = + SatelliteServerInfo::InternalURLs.begin(); + iter != SatelliteServerInfo::InternalURLs.end(); ++iter) { + Variant ret = preg_match + (String(iter->c_str(), iter->size(), AttachLiteral), url); + if (ret.toInt64() > 0) { + return false; + } + } + return true; +} + /////////////////////////////////////////////////////////////////////////////// // InternalPageServer: LibEventServer + allowed URL checking -DECLARE_BOOST_TYPES(InternalPageServerImpl); -class InternalPageServerImpl : public LibEventServer { -public: - InternalPageServerImpl(const std::string &address, int port, int thread, - int timeoutSeconds) : - LibEventServer(address, port, thread, timeoutSeconds) { - } - void create(const std::set &urls) { - m_allowedURLs = urls; - } - - virtual bool shouldHandle(const std::string &cmd) { - String url(cmd.c_str(), cmd.size(), AttachLiteral); - for (set::const_iterator iter = m_allowedURLs.begin(); - iter != m_allowedURLs.end(); ++iter) { - Variant ret = preg_match - (String(iter->c_str(), iter->size(), AttachLiteral), url); - if (ret.toInt64() > 0) { - return true; - } - } - return false; - } - -private: - std::set m_allowedURLs; -}; - class InternalPageServer : public SatelliteServer { public: - explicit InternalPageServer(SatelliteServerInfoPtr info) { - auto const server = boost::make_shared( + explicit InternalPageServer(SatelliteServerInfoPtr info) + : m_allowedURLs(info->getURLs()) { + m_server = boost::make_shared( RuntimeOption::ServerIP, info->getPort(), info->getThreadCount(), info->getTimeoutSeconds()); - server->setRequestHandlerFactory(); - server->create(info->getURLs()); - m_server = server; + m_server->setRequestHandlerFactory(); + m_server->setUrlChecker(std::bind(&InternalPageServer::checkURL, this, + std::placeholders::_1)); } virtual void start() { @@ -119,8 +106,22 @@ public: m_server->stop(); m_server->waitForEnd(); } + private: + bool checkURL(const std::string &path) const { + String url(path.c_str(), path.size(), AttachLiteral); + for (const auto &allowed : m_allowedURLs) { + Variant ret = preg_match + (String(allowed.c_str(), allowed.size(), AttachLiteral), url); + if (ret.toInt64() > 0) { + return true; + } + } + return false; + } + ServerPtr m_server; + std::set m_allowedURLs; }; /////////////////////////////////////////////////////////////////////////////// diff --git a/hphp/runtime/base/server/satellite_server.h b/hphp/runtime/base/server/satellite_server.h index a542d7ac7..6ec379e29 100644 --- a/hphp/runtime/base/server/satellite_server.h +++ b/hphp/runtime/base/server/satellite_server.h @@ -62,8 +62,13 @@ public: static std::set InternalURLs; static int DanglingServerPort; + /** + * Check whether a requested path should be allowed on the main server. + */ + static bool checkMainURL(const std::string& path); + public: - SatelliteServerInfo(Hdf hdf); + explicit SatelliteServerInfo(Hdf hdf); const std::string &getName() const { return m_name;} SatelliteServer::Type getType() const { return m_type;} diff --git a/hphp/runtime/base/server/server.cpp b/hphp/runtime/base/server/server.cpp index 0bd5cd7ef..42b369628 100644 --- a/hphp/runtime/base/server/server.cpp +++ b/hphp/runtime/base/server/server.cpp @@ -14,8 +14,8 @@ +----------------------------------------------------------------------+ */ -#include "hphp/runtime/base/complex_types.h" #include "hphp/runtime/base/server/server.h" +#include "hphp/runtime/base/complex_types.h" #include "hphp/runtime/base/server/satellite_server.h" #include "hphp/runtime/base/preg.h" #include @@ -50,22 +50,9 @@ void Server::InstallStopSignalHandlers(ServerPtr server) { Server::Server(const std::string &address, int port, int threadCount) : m_address(address), m_port(port), m_threadCount(threadCount), + m_urlChecker(SatelliteServerInfo::checkMainURL), m_status(NOT_YET_STARTED) { } -bool Server::shouldHandle(const std::string &cmd) { - String url(cmd.c_str(), cmd.size(), AttachLiteral); - for (std::set::const_iterator iter = - SatelliteServerInfo::InternalURLs.begin(); - iter != SatelliteServerInfo::InternalURLs.end(); ++iter) { - Variant ret = preg_match - (String(iter->c_str(), iter->size(), AttachLiteral), url); - if (ret.toInt64() > 0) { - return false; - } - } - return true; -} - /////////////////////////////////////////////////////////////////////////////// } diff --git a/hphp/runtime/base/server/server.h b/hphp/runtime/base/server/server.h index 5a1cb63fb..5c8bdb38c 100644 --- a/hphp/runtime/base/server/server.h +++ b/hphp/runtime/base/server/server.h @@ -86,6 +86,7 @@ public: }; typedef std::function()> RequestHandlerFactory; +typedef std::function URLChecker; /** * Base class of an HTTP server. Defining minimal interface an HTTP server @@ -135,6 +136,16 @@ public: }); } + /** + * Set the URLChecker function which determines which paths this server is + * allowed to server. + * + * Defaults to SatelliteServerInfo::checkURL() + */ + void setUrlChecker(const URLChecker& checker) { + m_urlChecker = checker; + } + /** * Informational. */ @@ -192,9 +203,11 @@ public: } /** - * Overwrite for URL blocking. + * Check whether a request to the specified server path is allowed. */ - virtual bool shouldHandle(const std::string &cmd); + bool shouldHandle(const std::string &path) { + return m_urlChecker(path); + } /** * To enable SSL of the current server, it will listen to an additional @@ -208,6 +221,7 @@ protected: int m_threadCount; mutable Mutex m_mutex; RequestHandlerFactory m_handlerFactory; + URLChecker m_urlChecker; private: RunStatus m_status;