Support Cursor OAuth client registration (#895)

This commit is contained in:
Sam Tombury
2025-06-07 15:24:11 +01:00
committed by GitHub
parent 8276632caa
commit 2bce10bdb1
5 changed files with 16 additions and 16 deletions

View File

@@ -2,7 +2,7 @@ import logging
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, Literal from typing import Any, Literal
from pydantic import AnyHttpUrl, AnyUrl, BaseModel, Field, RootModel, ValidationError from pydantic import AnyUrl, BaseModel, Field, RootModel, ValidationError
from starlette.datastructures import FormData, QueryParams from starlette.datastructures import FormData, QueryParams
from starlette.requests import Request from starlette.requests import Request
from starlette.responses import RedirectResponse, Response from starlette.responses import RedirectResponse, Response
@@ -29,7 +29,7 @@ logger = logging.getLogger(__name__)
class AuthorizationRequest(BaseModel): class AuthorizationRequest(BaseModel):
# See https://datatracker.ietf.org/doc/html/rfc6749#section-4.1.1 # See https://datatracker.ietf.org/doc/html/rfc6749#section-4.1.1
client_id: str = Field(..., description="The client ID") client_id: str = Field(..., description="The client ID")
redirect_uri: AnyHttpUrl | None = Field( redirect_uri: AnyUrl | None = Field(
None, description="URL to redirect to after authorization" None, description="URL to redirect to after authorization"
) )
@@ -68,8 +68,8 @@ def best_effort_extract_string(
return None return None
class AnyHttpUrlModel(RootModel[AnyHttpUrl]): class AnyUrlModel(RootModel[AnyUrl]):
root: AnyHttpUrl root: AnyUrl
@dataclass @dataclass
@@ -116,7 +116,7 @@ class AuthorizationHandler:
if params is not None and "redirect_uri" not in params: if params is not None and "redirect_uri" not in params:
raw_redirect_uri = None raw_redirect_uri = None
else: else:
raw_redirect_uri = AnyHttpUrlModel.model_validate( raw_redirect_uri = AnyUrlModel.model_validate(
best_effort_extract_string("redirect_uri", params) best_effort_extract_string("redirect_uri", params)
).root ).root
redirect_uri = client.validate_redirect_uri(raw_redirect_uri) redirect_uri = client.validate_redirect_uri(raw_redirect_uri)

View File

@@ -4,7 +4,7 @@ import time
from dataclasses import dataclass from dataclasses import dataclass
from typing import Annotated, Any, Literal from typing import Annotated, Any, Literal
from pydantic import AnyHttpUrl, BaseModel, Field, RootModel, ValidationError from pydantic import AnyHttpUrl, AnyUrl, BaseModel, Field, RootModel, ValidationError
from starlette.requests import Request from starlette.requests import Request
from mcp.server.auth.errors import ( from mcp.server.auth.errors import (
@@ -27,7 +27,7 @@ class AuthorizationCodeRequest(BaseModel):
# See https://datatracker.ietf.org/doc/html/rfc6749#section-4.1.3 # See https://datatracker.ietf.org/doc/html/rfc6749#section-4.1.3
grant_type: Literal["authorization_code"] grant_type: Literal["authorization_code"]
code: str = Field(..., description="The authorization code") code: str = Field(..., description="The authorization code")
redirect_uri: AnyHttpUrl | None = Field( redirect_uri: AnyUrl | None = Field(
None, description="Must be the same as redirect URI provided in /authorize" None, description="Must be the same as redirect URI provided in /authorize"
) )
client_id: str client_id: str

View File

@@ -2,7 +2,7 @@ from dataclasses import dataclass
from typing import Generic, Literal, Protocol, TypeVar from typing import Generic, Literal, Protocol, TypeVar
from urllib.parse import parse_qs, urlencode, urlparse, urlunparse from urllib.parse import parse_qs, urlencode, urlparse, urlunparse
from pydantic import AnyHttpUrl, BaseModel from pydantic import AnyUrl, BaseModel
from mcp.shared.auth import ( from mcp.shared.auth import (
OAuthClientInformationFull, OAuthClientInformationFull,
@@ -14,7 +14,7 @@ class AuthorizationParams(BaseModel):
state: str | None state: str | None
scopes: list[str] | None scopes: list[str] | None
code_challenge: str code_challenge: str
redirect_uri: AnyHttpUrl redirect_uri: AnyUrl
redirect_uri_provided_explicitly: bool redirect_uri_provided_explicitly: bool
@@ -24,7 +24,7 @@ class AuthorizationCode(BaseModel):
expires_at: float expires_at: float
client_id: str client_id: str
code_challenge: str code_challenge: str
redirect_uri: AnyHttpUrl redirect_uri: AnyUrl
redirect_uri_provided_explicitly: bool redirect_uri_provided_explicitly: bool

View File

@@ -1,6 +1,6 @@
from typing import Any, Literal from typing import Any, Literal
from pydantic import AnyHttpUrl, BaseModel, Field from pydantic import AnyHttpUrl, AnyUrl, BaseModel, Field
class OAuthToken(BaseModel): class OAuthToken(BaseModel):
@@ -32,7 +32,7 @@ class OAuthClientMetadata(BaseModel):
for the full specification. for the full specification.
""" """
redirect_uris: list[AnyHttpUrl] = Field(..., min_length=1) redirect_uris: list[AnyUrl] = Field(..., min_length=1)
# token_endpoint_auth_method: this implementation only supports none & # token_endpoint_auth_method: this implementation only supports none &
# client_secret_post; # client_secret_post;
# ie: we do not support client_secret_basic # ie: we do not support client_secret_basic
@@ -71,7 +71,7 @@ class OAuthClientMetadata(BaseModel):
raise InvalidScopeError(f"Client was not registered with scope {scope}") raise InvalidScopeError(f"Client was not registered with scope {scope}")
return requested_scopes return requested_scopes
def validate_redirect_uri(self, redirect_uri: AnyHttpUrl | None) -> AnyHttpUrl: def validate_redirect_uri(self, redirect_uri: AnyUrl | None) -> AnyUrl:
if redirect_uri is not None: if redirect_uri is not None:
# Validate redirect_uri against client's registered redirect URIs # Validate redirect_uri against client's registered redirect URIs
if redirect_uri not in self.redirect_uris: if redirect_uri not in self.redirect_uris:

View File

@@ -11,7 +11,7 @@ from urllib.parse import parse_qs, urlparse
import httpx import httpx
import pytest import pytest
from inline_snapshot import snapshot from inline_snapshot import snapshot
from pydantic import AnyHttpUrl from pydantic import AnyHttpUrl, AnyUrl
from mcp.client.auth import OAuthClientProvider from mcp.client.auth import OAuthClientProvider
from mcp.server.auth.routes import build_metadata from mcp.server.auth.routes import build_metadata
@@ -52,7 +52,7 @@ def mock_storage():
@pytest.fixture @pytest.fixture
def client_metadata(): def client_metadata():
return OAuthClientMetadata( return OAuthClientMetadata(
redirect_uris=[AnyHttpUrl("http://localhost:3000/callback")], redirect_uris=[AnyUrl("http://localhost:3000/callback")],
client_name="Test Client", client_name="Test Client",
grant_types=["authorization_code", "refresh_token"], grant_types=["authorization_code", "refresh_token"],
response_types=["code"], response_types=["code"],
@@ -79,7 +79,7 @@ def oauth_client_info():
return OAuthClientInformationFull( return OAuthClientInformationFull(
client_id="test_client_id", client_id="test_client_id",
client_secret="test_client_secret", client_secret="test_client_secret",
redirect_uris=[AnyHttpUrl("http://localhost:3000/callback")], redirect_uris=[AnyUrl("http://localhost:3000/callback")],
client_name="Test Client", client_name="Test Client",
grant_types=["authorization_code", "refresh_token"], grant_types=["authorization_code", "refresh_token"],
response_types=["code"], response_types=["code"],