diff --git a/Client/Auth/WebsocketAuthenticationProvider.php b/Client/Auth/WebsocketAuthenticationProvider.php index 880fdbcc..bd3e4e75 100644 --- a/Client/Auth/WebsocketAuthenticationProvider.php +++ b/Client/Auth/WebsocketAuthenticationProvider.php @@ -77,7 +77,11 @@ protected function getToken(ConnectionInterface $connection) } if (null === $token) { - $token = new AnonymousToken($this->firewalls[0], 'anon-' . $connection->WAMP->sessionId); + if (!$this->tokenStorage->getToken()) { + $token = new AnonymousToken($this->firewalls[0], 'anon-' . $connection->WAMP->sessionId); + } else { + $token = $this->tokenStorage->getToken(); + } } if ($this->tokenStorage->getToken() !== $token) { diff --git a/DependencyInjection/Configuration.php b/DependencyInjection/Configuration.php index e4404a41..a8dd3f54 100644 --- a/DependencyInjection/Configuration.php +++ b/DependencyInjection/Configuration.php @@ -95,6 +95,11 @@ public function getConfigTreeBuilder() ->end() ->end() ->end() + ->arrayNode('handshake_middleware') + ->prototype('scalar') + ->example('@some_service # have to extends Gos\Bundle\WebSocketBundle\Server\App\Stack\HandshakeMiddlewareAbstract') + ->end() + ->end() ->end() ->end() ->arrayNode('rpc') diff --git a/DependencyInjection/GosWebSocketExtension.php b/DependencyInjection/GosWebSocketExtension.php index 48a7e2dc..aa073d24 100644 --- a/DependencyInjection/GosWebSocketExtension.php +++ b/DependencyInjection/GosWebSocketExtension.php @@ -64,6 +64,16 @@ public function load(array $configs, ContainerBuilder $container) } } + if (!empty($configs['server']['handshake_middleware'])) { + $HandshakeMiddlewareRegistryDef = $container->getDefinition('gos_web_socket.handshake_middleware.registry'); + + foreach ($configs['server']['handshake_middleware'] as $middleware) { + $HandshakeMiddlewareRegistryDef->addMethodCall('addMiddleware', [new Reference(ltrim($middleware, '@'))]); + } + } + + + $container->setParameter('web_socket_server.client_storage.ttl', $configs['client']['storage']['ttl']); $container->setParameter('web_socket_server.client_storage.prefix', $configs['client']['storage']['prefix']); diff --git a/Event/ClientEventListener.php b/Event/ClientEventListener.php index c982a7ad..01c2b228 100644 --- a/Event/ClientEventListener.php +++ b/Event/ClientEventListener.php @@ -3,6 +3,7 @@ namespace Gos\Bundle\WebSocketBundle\Event; use Gos\Bundle\WebSocketBundle\Client\Auth\WebsocketAuthenticationProvider; +use Gos\Bundle\WebSocketBundle\Client\Auth\WebsocketAuthenticationProviderInterface; use Gos\Bundle\WebSocketBundle\Client\ClientStorageInterface; use Gos\Bundle\WebSocketBundle\Client\Exception\ClientNotFoundException; use Psr\Log\LoggerInterface; @@ -31,12 +32,12 @@ class ClientEventListener /** * @param ClientStorageInterface $clientStorage - * @param WebsocketAuthenticationProvider $authenticationProvider + * @param WebsocketAuthenticationProviderInterface $authenticationProvider * @param LoggerInterface|null $logger */ public function __construct( ClientStorageInterface $clientStorage, - WebsocketAuthenticationProvider $authenticationProvider, + WebsocketAuthenticationProviderInterface $authenticationProvider, LoggerInterface $logger = null ) { $this->clientStorage = $clientStorage; @@ -124,8 +125,6 @@ public function onClientError(ClientErrorEvent $event) */ public function onClientRejected(ClientRejectedEvent $event) { - $this->logger->warning('Client rejected, bad origin', [ - 'origin' => $event->getOrigin(), - ]); + $this->logger->warning('Client rejected, ' . $event->getMsg()); } } diff --git a/Event/ClientRejectedEvent.php b/Event/ClientRejectedEvent.php index 7963b4e0..f2f3e951 100644 --- a/Event/ClientRejectedEvent.php +++ b/Event/ClientRejectedEvent.php @@ -13,7 +13,7 @@ class ClientRejectedEvent extends Event /** * @var string */ - protected $origin; + protected $msg; /** * @var RequestInterface @@ -21,21 +21,21 @@ class ClientRejectedEvent extends Event protected $request; /** - * @param string $origin + * @param string $msg * @param RequestInterface $request */ - public function __construct($origin, RequestInterface $request = null) + public function __construct($msg, RequestInterface $request = null) { - $this->origin = $origin; + $this->msg = $msg; $this->request = $request; } /** * @return string */ - public function getOrigin() + public function getMsg() { - return $this->origin; + return $this->msg; } /** diff --git a/README.md b/README.md index a276e4b7..e076929d 100644 --- a/README.md +++ b/README.md @@ -52,6 +52,7 @@ Resources * [Performance Bench](Resources/docs/Performance.md) * [Push integration](Resources/docs/Pusher.md) * [SSL configuration](Resources/docs/Ssl.md) +* [Handshake Middleware For Server](Resources/docs/HandshakeMiddleware.md) Code Cookbook -------------- diff --git a/Resources/config/services/services.yml b/Resources/config/services/services.yml index b8c5903e..206e4a65 100644 --- a/Resources/config/services/services.yml +++ b/Resources/config/services/services.yml @@ -51,6 +51,7 @@ services: - '%web_socket_origin_check%' - '@gos_web_socket.wamp.topic_manager' - '@gos_web_socket.server_push_handler.registry' + - '@gos_web_socket.handshake_middleware.registry' - '@?monolog.logger.websocket' tags: - { name: gos_web_socket.server } @@ -85,6 +86,9 @@ services: gos_web_socket.origins.registry: class: Gos\Bundle\WebSocketBundle\Server\App\Registry\OriginRegistry + gos_web_socket.handshake_middleware.registry: + class: Gos\Bundle\WebSocketBundle\Server\App\Registry\HandshakeMiddlewareRegistry + gos_web_socket_server.wamp_application: class: Gos\Bundle\WebSocketBundle\Server\App\WampApplication public: false diff --git a/Resources/docs/ConfigurationReference.md b/Resources/docs/ConfigurationReference.md index 7d46fcca..5b853e63 100644 --- a/Resources/docs/ConfigurationReference.md +++ b/Resources/docs/ConfigurationReference.md @@ -19,6 +19,9 @@ gos_web_socket: - @AcmeBundle/Resources/config/pubsub/routing.yml context: tokenSeparator: "/" + handshake_middleware: [] + # - @some_service + rpc: [] topics: [] periodic: [] diff --git a/Resources/docs/Events.md b/Resources/docs/Events.md index ef6af821..754fc935 100644 --- a/Resources/docs/Events.md +++ b/Resources/docs/Events.md @@ -81,9 +81,9 @@ class AcmeClientEventListener */ public function onClientRejected(ClientRejectedEvent $event) { - $origin = $event->getOrigin; + $msg = $event->getMsg(); - echo 'connection rejected from '. $origin . PHP_EOL; + echo $msg . PHP_EOL; } } ``` diff --git a/Resources/docs/HandshakeMiddleware.md b/Resources/docs/HandshakeMiddleware.md new file mode 100644 index 00000000..5b07e24f --- /dev/null +++ b/Resources/docs/HandshakeMiddleware.md @@ -0,0 +1,126 @@ +# HandshakeMiddleware + +You can add any middleware as service to server with your business logic + + +**Bundle Configuration** + +```yaml +# Gos Web Socket Bundle +gos_web_socket: + server: + handshake_middleware: + - @some_service # have to extends Gos\Bundle\WebSocketBundle\Server\App\Stack\HandshakeMiddlewareAbstract +``` + + + +### Handshake middleware example for OAuth + +```php +oAuthService = $oAuthService; + $this->eventDispatcher = $eventDispatcher; + $this->firewalls = $firewalls; + $this->tokenStorage = $tokenStorage; + } + + /** + * @param ConnectionInterface $conn + * @param RequestInterface|null $request + * + * @return void + */ + public function onOpen(ConnectionInterface $conn, RequestInterface $request = null) + { + try { + $accessToken = $this->oAuthService->verifyAccessToken($request->getQuery()->get('access_token')); + } catch (OAuth2AuthenticateException $e) { + $this->eventDispatcher->dispatch( + Events::CLIENT_REJECTED, + new ClientRejectedEvent($e->getMessage(), $request) + ); + + $this->close($conn, 403); + return ; + } + + $user = $accessToken->getUser(); + $token = new AnonymousToken( + $request->getQuery()->get('access_token'), + $user, + $user->getRoles() + ); + $this->tokenStorage->setToken($token); + + return $this->_component->onOpen($conn, $request); + } + + /** + * Close a connection with an HTTP response. + * + * @param \Ratchet\ConnectionInterface $conn + * @param int $code HTTP status code + */ + protected function close(ConnectionInterface $conn, $code = 400) + { + $response = new Response($code, [ + 'X-Powered-By' => \Ratchet\VERSION, + ]); + + $conn->send((string)$response); + $conn->close(); + } +} +``` \ No newline at end of file diff --git a/Server/App/Registry/HandshakeMiddlewareRegistry.php b/Server/App/Registry/HandshakeMiddlewareRegistry.php new file mode 100644 index 00000000..a040e4e4 --- /dev/null +++ b/Server/App/Registry/HandshakeMiddlewareRegistry.php @@ -0,0 +1,37 @@ + + */ +class HandshakeMiddlewareRegistry +{ + /** + * @var HandshakeMiddlewareAbstract[] + */ + protected $middlewares; + + public function __construct() + { + $this->middlewares = []; + } + + /** + * @param HandshakeMiddlewareAbstract $middleware + */ + public function addMiddleware(HandshakeMiddlewareAbstract $middleware) + { + $this->middlewares[] = $middleware; + } + + /** + * @return HandshakeMiddlewareAbstract[] + */ + public function getMiddlewares() + { + return $this->middlewares; + } +} diff --git a/Server/App/Stack/Factory/Middleware.php b/Server/App/Stack/Factory/Middleware.php new file mode 100644 index 00000000..5a602c59 --- /dev/null +++ b/Server/App/Stack/Factory/Middleware.php @@ -0,0 +1,73 @@ + + */ +class Middleware implements HttpServerInterface +{ + /** + * @var \Ratchet\MessageComponentInterface + */ + protected $_component; + + /** + * @var HandshakeMiddlewareAbstract + */ + protected $_middleware; + + /** + * @param MessageComponentInterface $component + * @param HandshakeMiddlewareAbstract $middleware + */ + public function __construct( + MessageComponentInterface $component, + HandshakeMiddlewareAbstract $middleware + ) { + $this->_component = $component; + $this->_middleware = $middleware; + $this->_middleware->setComponent($component); + } + + /** + * {@inheritdoc} + */ + public function onOpen(ConnectionInterface $conn, RequestInterface $request = null) + { + return $this->_middleware->onOpen($conn, $request); + } + + /** + * {@inheritdoc} + */ + public function onMessage(ConnectionInterface $from, $msg) + { + return $this->_middleware->onMessage($from, $msg); + } + + /** + * {@inheritdoc} + */ + public function onClose(ConnectionInterface $conn) + { + return $this->_middleware->onClose($conn); + } + + /** + * {@inheritdoc} + */ + public function onError(ConnectionInterface $conn, \Exception $e) + { + return $this->_middleware->onError($conn, $e); + } +} diff --git a/Server/App/Stack/HandshakeMiddlewareAbstract.php b/Server/App/Stack/HandshakeMiddlewareAbstract.php new file mode 100644 index 00000000..6021e78c --- /dev/null +++ b/Server/App/Stack/HandshakeMiddlewareAbstract.php @@ -0,0 +1,66 @@ + + */ +abstract class HandshakeMiddlewareAbstract implements HttpServerInterface +{ + /** + * @var MessageComponentInterface + */ + protected $_component; + + /** + * @param MessageComponentInterface $component + */ + public function setComponent(MessageComponentInterface $component) + { + $this->_component = $component; + } + + /** + * @param ConnectionInterface $conn + * @param RequestInterface|null $request + * @return mixed + */ + public function onOpen(ConnectionInterface $conn, RequestInterface $request = null) + { + return $this->_component->onOpen($conn, $request); + } + + /** + * @param ConnectionInterface $conn + * @return mixed + */ + public function onClose(ConnectionInterface $conn) + { + return $this->_component->onClose($conn); + } + + /** + * @param ConnectionInterface $conn + * @param \Exception $e + * @return mixed + */ + public function onError(ConnectionInterface $conn, \Exception $e) + { + return $this->_component->onError($conn, $e); + } + + /** + * @param ConnectionInterface $from + * @param string $msg + * @return mixed + */ + public function onMessage(ConnectionInterface $from, $msg) + { + return $this->_component->onMessage($from, $msg); + } +} diff --git a/Server/App/Stack/OriginCheck.php b/Server/App/Stack/OriginCheck.php index 030a712c..a0682120 100644 --- a/Server/App/Stack/OriginCheck.php +++ b/Server/App/Stack/OriginCheck.php @@ -45,7 +45,7 @@ public function onOpen(ConnectionInterface $conn, RequestInterface $request = nu if (!in_array($origin, $this->allowedOrigins)) { $this->eventDispatcher->dispatch( Events::CLIENT_REJECTED, - new ClientRejectedEvent($origin, $request) + new ClientRejectedEvent('connection rejected from '. $origin, $request) ); return $this->close($conn, 403); diff --git a/Server/Type/WebSocketServer.php b/Server/Type/WebSocketServer.php index 861da907..fbe0e4bb 100644 --- a/Server/Type/WebSocketServer.php +++ b/Server/Type/WebSocketServer.php @@ -7,6 +7,7 @@ use Gos\Bundle\WebSocketBundle\Periodic\PeriodicInterface; use Gos\Bundle\WebSocketBundle\Periodic\PeriodicMemoryUsage; use Gos\Bundle\WebSocketBundle\Pusher\ServerPushHandlerRegistry; +use Gos\Bundle\WebSocketBundle\Server\App\Registry\HandshakeMiddlewareRegistry; use Gos\Bundle\WebSocketBundle\Server\App\Registry\OriginRegistry; use Gos\Bundle\WebSocketBundle\Server\App\Registry\PeriodicRegistry; use Gos\Bundle\WebSocketBundle\Server\App\WampApplication; @@ -70,6 +71,9 @@ class WebSocketServer implements ServerInterface /** @var TopicManager */ protected $topicManager; + /** @var HandshakeMiddlewareRegistry */ + protected $handshakeMiddlewareRegistry; + /** * @param LoopInterface $loop * @param EventDispatcherInterface $eventDispatcher @@ -78,6 +82,8 @@ class WebSocketServer implements ServerInterface * @param OriginRegistry $originRegistry * @param bool $originCheck * @param TopicManager $topicManager + * @param ServerPushHandlerRegistry $serverPushHandlerRegistry + * @param HandshakeMiddlewareRegistry $handshakeMiddlewareRegistry * @param LoggerInterface|null $logger */ public function __construct( @@ -89,6 +95,7 @@ public function __construct( $originCheck, TopicManager $topicManager, ServerPushHandlerRegistry $serverPushHandlerRegistry, + HandshakeMiddlewareRegistry $handshakeMiddlewareRegistry, LoggerInterface $logger = null ) { $this->loop = $loop; @@ -101,6 +108,7 @@ public function __construct( $this->topicManager = $topicManager; $this->serverPusherHandlerRegistry = $serverPushHandlerRegistry; $this->sessionHandler = new NullSessionHandler(); + $this->handshakeMiddlewareRegistry = $handshakeMiddlewareRegistry; } /** @@ -154,6 +162,13 @@ public function launch($host, $port, $profile) $stack->push('Gos\Bundle\WebSocketBundle\Server\App\Stack\OriginCheck', $allowedOrigins, $this->eventDispatcher); } + + if (!empty($this->handshakeMiddlewareRegistry->getMiddlewares())) { + foreach ($this->handshakeMiddlewareRegistry->getMiddlewares() as $middleware) { + call_user_func([$stack, 'push'], 'Gos\Bundle\WebSocketBundle\Server\App\Stack\Factory\Middleware', $middleware); + } + } + $stack ->push('Ratchet\WebSocket\WsServer') ->push('Gos\Bundle\WebSocketBundle\Server\App\Stack\WampConnectionPeriodicTimer', $this->loop)