Add strict mode to pyright (#315)

* Add strict mode to pyright

* Apply UP rule

* fix readme

* More correct

* Leave wrong Context for now

* Add strict mode to pyright

* Apply UP rule

* fix readme

* fix

* ignore
This commit is contained in:
Marcelo Trylesinski
2025-03-20 09:13:08 +00:00
committed by GitHub
parent 5a54d82459
commit ae77772ea8
27 changed files with 194 additions and 133 deletions

View File

@@ -2,8 +2,8 @@
import inspect
import json
from collections.abc import Callable
from typing import Any, Awaitable, Literal, Sequence
from collections.abc import Awaitable, Callable, Sequence
from typing import Any, Literal
import pydantic_core
from pydantic import BaseModel, Field, TypeAdapter, validate_call
@@ -19,7 +19,7 @@ class Message(BaseModel):
role: Literal["user", "assistant"]
content: CONTENT_TYPES
def __init__(self, content: str | CONTENT_TYPES, **kwargs):
def __init__(self, content: str | CONTENT_TYPES, **kwargs: Any):
if isinstance(content, str):
content = TextContent(type="text", text=content)
super().__init__(content=content, **kwargs)
@@ -30,7 +30,7 @@ class UserMessage(Message):
role: Literal["user", "assistant"] = "user"
def __init__(self, content: str | CONTENT_TYPES, **kwargs):
def __init__(self, content: str | CONTENT_TYPES, **kwargs: Any):
super().__init__(content=content, **kwargs)
@@ -39,11 +39,13 @@ class AssistantMessage(Message):
role: Literal["user", "assistant"] = "assistant"
def __init__(self, content: str | CONTENT_TYPES, **kwargs):
def __init__(self, content: str | CONTENT_TYPES, **kwargs: Any):
super().__init__(content=content, **kwargs)
message_validator = TypeAdapter(UserMessage | AssistantMessage)
message_validator = TypeAdapter[UserMessage | AssistantMessage](
UserMessage | AssistantMessage
)
SyncPromptResult = (
str | Message | dict[str, Any] | Sequence[str | Message | dict[str, Any]]
@@ -73,12 +75,12 @@ class Prompt(BaseModel):
arguments: list[PromptArgument] | None = Field(
None, description="Arguments that can be passed to the prompt"
)
fn: Callable = Field(exclude=True)
fn: Callable[..., PromptResult | Awaitable[PromptResult]] = Field(exclude=True)
@classmethod
def from_function(
cls,
fn: Callable[..., PromptResult],
fn: Callable[..., PromptResult | Awaitable[PromptResult]],
name: str | None = None,
description: str | None = None,
) -> "Prompt":
@@ -99,7 +101,7 @@ class Prompt(BaseModel):
parameters = TypeAdapter(fn).json_schema()
# Convert parameters to PromptArguments
arguments = []
arguments: list[PromptArgument] = []
if "properties" in parameters:
for param_name, param in parameters["properties"].items():
required = param_name in parameters.get("required", [])
@@ -138,25 +140,23 @@ class Prompt(BaseModel):
result = await result
# Validate messages
if not isinstance(result, (list, tuple)):
if not isinstance(result, list | tuple):
result = [result]
# Convert result to messages
messages = []
for msg in result:
messages: list[Message] = []
for msg in result: # type: ignore[reportUnknownVariableType]
try:
if isinstance(msg, Message):
messages.append(msg)
elif isinstance(msg, dict):
msg = message_validator.validate_python(msg)
messages.append(msg)
messages.append(message_validator.validate_python(msg))
elif isinstance(msg, str):
messages.append(
UserMessage(content=TextContent(type="text", text=msg))
)
content = TextContent(type="text", text=msg)
messages.append(UserMessage(content=content))
else:
msg = json.dumps(pydantic_core.to_jsonable_python(msg))
messages.append(Message(role="user", content=msg))
content = json.dumps(pydantic_core.to_jsonable_python(msg))
messages.append(Message(role="user", content=content))
except Exception:
raise ValueError(
f"Could not convert prompt result to message: {msg}"