diff --git a/examples/public/client_credentials.php b/examples/public/client_credentials.php index ca43fafb..1e671d3a 100644 --- a/examples/public/client_credentials.php +++ b/examples/public/client_credentials.php @@ -32,9 +32,9 @@ $app->post('/access_token', function (Request $request, Response $response) { /** @var Server $server */ $server = $this->get(Server::class); try { - return $server->respondToRequest($request); + return $server->respondToRequest($request, $response); } catch (OAuthServerException $e) { - return $e->generateHttpResponse(); + return $e->generateHttpResponse($response); } catch (\Exception $e) { return $response->withStatus(500)->write($e->getMessage()); } diff --git a/examples/public/password.php b/examples/public/password.php index d9ff5f8b..de73d63d 100644 --- a/examples/public/password.php +++ b/examples/public/password.php @@ -43,9 +43,9 @@ $app->post('/access_token', function (Request $request, Response $response) { /** @var Server $server */ $server = $this->get(Server::class); try { - return $server->respondToRequest($request); + return $server->respondToRequest($request, $response); } catch (OAuthServerException $e) { - return $e->generateHttpResponse(); + return $e->generateHttpResponse($response); } catch (\Exception $e) { return $response->withStatus(500)->write($e->getMessage()); } diff --git a/examples/public/refresh_token.php b/examples/public/refresh_token.php index 75c08139..d41e940b 100644 --- a/examples/public/refresh_token.php +++ b/examples/public/refresh_token.php @@ -43,9 +43,9 @@ $app->post('/access_token', function (Request $request, Response $response) { /** @var Server $server */ $server = $this->get(Server::class); try { - return $server->respondToRequest($request); + return $server->respondToRequest($request, $response); } catch (OAuthServerException $e) { - return $e->generateHttpResponse(); + return $e->generateHttpResponse($response); } catch (\Exception $e) { return $response->withStatus(500)->write( sprintf('
%s
', get_class($e), $e->getMessage()) diff --git a/src/Exception/OAuthServerException.php b/src/Exception/OAuthServerException.php index 28c9b799..f97e3b5a 100644 --- a/src/Exception/OAuthServerException.php +++ b/src/Exception/OAuthServerException.php @@ -204,8 +204,12 @@ class OAuthServerException extends \Exception * * @return ResponseInterface */ - public function generateHttpResponse() + public function generateHttpResponse(ResponseInterface $response = null) { + if (!$response instanceof ResponseInterface) { + $response = new Response(); + } + $headers = $this->getHttpHeaders(); $payload = [ @@ -221,12 +225,13 @@ class OAuthServerException extends \Exception $headers['Location'] = RedirectUri::make($this->redirectUri, $payload); } - $response = new Response( - 'php://memory', - $this->getHttpStatusCode(), - $headers - ); - $response->getBody()->write(json_encode($payload)); + foreach ($headers as $header => $content) { + $response->withHeader($header, $content); + } + + $response + ->withStatus($this->getHttpStatusCode()) + ->getBody()->write(json_encode($payload)); return $response; } @@ -254,8 +259,7 @@ class OAuthServerException extends \Exception if ($this->errorType === 'invalid_client') { $authScheme = null; $request = new ServerRequest(); - if ( - isset($request->getServerParams()['PHP_AUTH_USER']) && + if (isset($request->getServerParams()['PHP_AUTH_USER']) && $request->getServerParams()['PHP_AUTH_USER'] !== null ) { $authScheme = 'Basic'; diff --git a/src/ResponseTypes/BearerTokenResponse.php b/src/ResponseTypes/BearerTokenResponse.php index 85b5c235..2397286e 100644 --- a/src/ResponseTypes/BearerTokenResponse.php +++ b/src/ResponseTypes/BearerTokenResponse.php @@ -16,6 +16,7 @@ use Lcobucci\JWT\Signer\Key; use Lcobucci\JWT\Signer\Rsa\Sha256; use League\OAuth2\Server\Entities\Interfaces\RefreshTokenEntityInterface; use League\OAuth2\Server\Utils\KeyCrypt; +use Psr\Http\Message\ResponseInterface; use Psr\Http\Message\ServerRequestInterface; use Zend\Diactoros\Response; @@ -24,7 +25,7 @@ class BearerTokenResponse extends AbstractResponseType /** * {@inheritdoc} */ - public function generateHttpResponse() + public function generateHttpResponse(ResponseInterface $response) { $jwtAccessToken = (new Builder()) ->setAudience($this->accessToken->getClient()->getIdentifier()) @@ -61,16 +62,12 @@ class BearerTokenResponse extends AbstractResponseType $responseParams['refresh_token'] = $refreshToken; } - $response = new Response( - 'php://memory', - 200, - [ - 'pragma' => 'no-cache', - 'cache-control' => 'no-store', - 'content-type' => 'application/json;charset=UTF-8' - ] - ); - $response->getBody()->write(json_encode($responseParams)); + $response + ->withStatus(200) + ->withHeader('pragma', 'no-cache') + ->withHeader('cache-control', 'no-store') + ->withHeader('content-type', 'application/json;charset=UTF-8') + ->getBody()->write(json_encode($responseParams)); return $response; } diff --git a/src/ResponseTypes/ResponseTypeInterface.php b/src/ResponseTypes/ResponseTypeInterface.php index fef9cb59..00c3ed14 100644 --- a/src/ResponseTypes/ResponseTypeInterface.php +++ b/src/ResponseTypes/ResponseTypeInterface.php @@ -38,7 +38,9 @@ interface ResponseTypeInterface public function determineAccessTokenInHeader(ServerRequestInterface $request); /** + * @param ResponseInterface $response + * * @return ResponseInterface */ - public function generateHttpResponse(); + public function generateHttpResponse(ResponseInterface $response); } diff --git a/src/Server.php b/src/Server.php index b18bddf4..cd7e6e46 100644 --- a/src/Server.php +++ b/src/Server.php @@ -9,7 +9,9 @@ use League\OAuth2\Server\Exception\OAuthServerException; use League\OAuth2\Server\Grant\GrantTypeInterface; use League\OAuth2\Server\ResponseTypes\BearerTokenResponse; use League\OAuth2\Server\ResponseTypes\ResponseTypeInterface; +use Psr\Http\Message\ResponseInterface; use Psr\Http\Message\ServerRequestInterface; +use Zend\Diactoros\Response; use Zend\Diactoros\ServerRequestFactory; class Server implements EmitterAwareInterface @@ -121,16 +123,21 @@ class Server implements EmitterAwareInterface * Return an access token response * * @param \Psr\Http\Message\ServerRequestInterface $request + * @param \Psr\Http\Message\ResponseInterface $response * - * @return \League\OAuth2\Server\ResponseTypes\ResponseTypeInterface + * @return \Psr\Http\Message\ResponseInterface * @throws \League\OAuth2\Server\Exception\OAuthServerException */ - public function respondToRequest(ServerRequestInterface $request = null) + public function respondToRequest(ServerRequestInterface $request = null, ResponseInterface $response = null) { - if ($request === null) { + if (!$request instanceof ServerRequestInterface) { $request = ServerRequestFactory::fromGlobals(); } + if (!$response instanceof ResponseInterface) { + $response = new Response(); + } + $tokenResponse = null; foreach ($this->enabledGrantTypes as $grantType) { if ($grantType->canRespondToRequest($request)) { @@ -143,12 +150,30 @@ class Server implements EmitterAwareInterface } } - if ($tokenResponse instanceof ResponseTypeInterface) { - return $tokenResponse->generateHttpResponse(); - } else { - $response = OAuthServerException::unsupportedGrantType()->generateHttpResponse(); + if (!$tokenResponse instanceof ResponseTypeInterface) { + return OAuthServerException::unsupportedGrantType()->generateHttpResponse($response); } - return $response; + return $tokenResponse->generateHttpResponse($response); + } + + /** + * PSR7 middleware callable + * + * @param \Psr\Http\Message\ServerRequestInterface $request + * @param \Psr\Http\Message\ResponseInterface $response + * + * @return \Psr\Http\Message\ResponseInterface + * @throws \League\OAuth2\Server\Exception\OAuthServerException + */ + public function __invoke(ServerRequestInterface $request, ResponseInterface $response, callable $next) + { + $response = $this->respondToRequest($request, $response); + + if (in_array($response->getStatusCode(), [400, 401, 500])) { + return $response; + } + + return $next($request, $response); } }