Skip to content

Improve run_api dynamic schema creation #110

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Apr 29, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -371,7 +371,9 @@ A [MCP Tool](https://modelcontextprotocol.io/docs/concepts/tools) requires the f
For each deployed pipeline, Hayhooks will:

- Use the pipeline wrapper `name` as MCP Tool `name` (always present).
- Use the pipeline wrapper **`run_api` method docstring** as MCP Tool `description` (if present).
- Parse **`run_api` method docstring**:
  - If you use Google-style or reStructuredText-style docstrings, use the first line as MCP Tool `description` and the rest as `parameters` (if present).
  - Each parameter description will be used as the `description` of the corresponding Pydantic model field (if present).
- Generate a Pydantic model from the `inputSchema` using the **`run_api` method arguments as fields**.

Here's an example of a PipelineWrapper implementation for the `chat_with_website` pipeline which can be used as an MCP Tool:
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ dependencies = [
"loguru",
"pydantic-settings",
"python-dotenv",
"docstring-parser",
]

[project.optional-dependencies]
Expand Down
41 changes: 25 additions & 16 deletions src/hayhooks/server/utils/deploy_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,17 @@
import tempfile
import traceback
import sys
import docstring_parser
from docstring_parser.common import Docstring
from functools import wraps
from pathlib import Path
from types import ModuleType
from typing import Callable, Union
from typing import Callable, Union, Any, Dict
from fastapi import FastAPI, Form, HTTPException
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import JSONResponse
from fastapi.routing import APIRoute
from pydantic import create_model
from pydantic import create_model, Field
from hayhooks.server.exceptions import (
PipelineAlreadyExistsError,
PipelineFilesError,
Expand Down Expand Up @@ -178,7 +180,7 @@ def load_pipeline_module(pipeline_name: str, dir_path: Union[Path, str]) -> Modu
raise PipelineModuleLoadError(error_msg) from e


def create_request_model_from_callable(func: Callable, model_name: str):
def create_request_model_from_callable(func: Callable, model_name: str, docstring: Docstring):
"""Create a dynamic Pydantic model based on callable's signature.

Args:
Expand All @@ -190,14 +192,19 @@ def create_request_model_from_callable(func: Callable, model_name: str):
"""

params = inspect.signature(func).parameters
fields = {
name: (param.annotation, ... if param.default == param.empty else param.default)
for name, param in params.items()
}
param_docs = {p.arg_name: p.description for p in docstring.params}

fields: Dict[str, Any] = {}
for name, param in params.items():
default_value = ... if param.default == param.empty else param.default
description = param_docs.get(name) or f"Parameter '{name}'"
field_info = Field(default=default_value, description=description)
fields[name] = (param.annotation, field_info)

return create_model(f'{model_name}Request', **fields)


def create_response_model_from_callable(func: Callable, model_name: str):
def create_response_model_from_callable(func: Callable, model_name: str, docstring: Docstring):
"""Create a dynamic Pydantic model based on callable's return type.

Args:
Expand All @@ -209,7 +216,9 @@ def create_response_model_from_callable(func: Callable, model_name: str):
"""

return_type = inspect.signature(func).return_annotation
return create_model(f'{model_name}Response', result=(return_type, ...))
return_description = docstring.returns.description if docstring.returns else None

return create_model(f'{model_name}Response', result=(return_type, Field(..., description=return_description)))


def handle_pipeline_exceptions():
Expand Down Expand Up @@ -270,8 +279,9 @@ async def run_endpoint_without_files(run_req: request_model) -> response_model:
def add_pipeline_api_route(app: FastAPI, pipeline_name: str, pipeline_wrapper: BasePipelineWrapper) -> None:
clog = log.bind(pipeline_name=pipeline_name)

RunRequest = create_request_model_from_callable(pipeline_wrapper.run_api, f'{pipeline_name}Run')
RunResponse = create_response_model_from_callable(pipeline_wrapper.run_api, f'{pipeline_name}Run')
docstring = docstring_parser.parse(inspect.getdoc(pipeline_wrapper.run_api) or "")
RunRequest = create_request_model_from_callable(pipeline_wrapper.run_api, f'{pipeline_name}Run', docstring)
RunResponse = create_response_model_from_callable(pipeline_wrapper.run_api, f'{pipeline_name}Run', docstring)

run_api_params = inspect.signature(pipeline_wrapper.run_api).parameters
requires_files = "files" in run_api_params
Expand All @@ -289,16 +299,14 @@ def add_pipeline_api_route(app: FastAPI, pipeline_name: str, pipeline_wrapper: B
if isinstance(route, APIRoute) and route.path == f"/{pipeline_name}/run":
app.routes.remove(route)

docstring = inspect.getdoc(pipeline_wrapper.run_api)

app.add_api_route(
path=f"/{pipeline_name}/run",
endpoint=run_endpoint,
methods=["POST"],
name=f"{pipeline_name}_run",
response_model=RunResponse,
tags=["pipelines"],
description=docstring or None,
description=docstring.short_description or None,
)

registry.update_metadata(
Expand Down Expand Up @@ -383,9 +391,10 @@ def add_pipeline_to_registry(
clog.debug("Running setup()")
pipeline_wrapper.setup()

docstring = docstring_parser.parse(inspect.getdoc(pipeline_wrapper.run_api) or "")
metadata = {
"description": inspect.getdoc(pipeline_wrapper.run_api) or "",
"request_model": create_request_model_from_callable(pipeline_wrapper.run_api, f'{pipeline_name}Run'),
"description": docstring.short_description or "",
"request_model": create_request_model_from_callable(pipeline_wrapper.run_api, f'{pipeline_name}Run', docstring),
"skip_mcp": pipeline_wrapper.skip_mcp,
}

Expand Down
117 changes: 103 additions & 14 deletions tests/test_deploy_utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from fastapi.routing import APIRoute
from hayhooks.server.pipelines import registry
import pytest
import shutil
import docstring_parser
import inspect
from fastapi.routing import APIRoute
from hayhooks.server.pipelines import registry
from haystack import Pipeline
from pathlib import Path
from typing import Callable
Expand Down Expand Up @@ -106,28 +108,113 @@ def test_save_pipeline_files_raises_error(tmp_path):

def test_create_request_model_from_callable():
def sample_func(name: str, age: int = 25, optional: str = ""):
"""Sample function with docstring.

Args:
name: The name of the person.
age: The age of the person.
optional: An optional string.
"""
pass

model = create_request_model_from_callable(sample_func, "Test")
docstring = docstring_parser.parse(inspect.getdoc(sample_func) or "")
model = create_request_model_from_callable(sample_func, "Test", docstring)
schema = model.model_json_schema()

assert model.__name__ == "TestRequest"
assert model.model_fields["name"].annotation == str
assert model.model_fields["name"].is_required
assert model.model_fields["age"].annotation == int
assert model.model_fields["age"].default == 25
assert model.model_fields["optional"].annotation == str
assert model.model_fields["optional"].default == ""
assert schema["properties"]["name"]["type"] == "string"
assert "default" not in schema["properties"]["name"]
assert schema["properties"]["name"]["description"] == "The name of the person."
assert "name" in schema["required"]

assert schema["properties"]["age"]["type"] == "integer"
assert schema["properties"]["age"]["default"] == 25
assert schema["properties"]["age"]["description"] == "The age of the person."
assert "age" not in schema.get("required", [])

assert schema["properties"]["optional"]["type"] == "string"
assert schema["properties"]["optional"]["default"] == ""
assert schema["properties"]["optional"]["description"] == "An optional string."
assert "optional" not in schema.get("required", [])


def test_create_request_model_no_docstring():
    """Check the request model built from a callable that has no docstring.

    Each field is expected to expose the generated ``Parameter '<name>'``
    text as its schema description.
    """

    # Deliberately left without a docstring: that is the case under test.
    def sample_func_no_doc(name: str, age: int = 30):
        pass

    parsed_doc = docstring_parser.parse(inspect.getdoc(sample_func_no_doc) or "")
    request_model = create_request_model_from_callable(sample_func_no_doc, "NoDoc", parsed_doc)
    json_schema = request_model.model_json_schema()

    assert request_model.__name__ == "NoDocRequest"

    # Parameter without a default value.
    name_props = json_schema["properties"]["name"]
    assert name_props["type"] == "string"
    assert name_props["description"] == "Parameter 'name'"
    assert "name" in json_schema["required"]

    # Parameter with a default value.
    age_props = json_schema["properties"]["age"]
    assert age_props["type"] == "integer"
    assert age_props["default"] == 30
    assert age_props["description"] == "Parameter 'age'"
    assert "age" not in json_schema.get("required", [])


def test_create_request_model_partial_docstring():
    """Check a request model built from a docstring that documents only one parameter.

    The documented parameter is expected to carry its docstring text as the
    schema description; the undocumented one is expected to carry the
    generated ``Parameter '<name>'`` text.
    """

    def sample_func_partial_doc(documented_param: str, undocumented_param: int = 42):
        """Sample function with partial docstring.

        Args:
            documented_param: This parameter is documented.
        """
        pass

    parsed_doc = docstring_parser.parse(inspect.getdoc(sample_func_partial_doc) or "")
    request_model = create_request_model_from_callable(sample_func_partial_doc, "PartialDoc", parsed_doc)
    json_schema = request_model.model_json_schema()

    assert request_model.__name__ == "PartialDocRequest"

    # Parameter described in the docstring, no default value.
    documented = json_schema["properties"]["documented_param"]
    assert documented["type"] == "string"
    assert "default" not in documented
    assert documented["description"] == "This parameter is documented."
    assert "documented_param" in json_schema["required"]

    # Parameter missing from the docstring, with a default value.
    undocumented = json_schema["properties"]["undocumented_param"]
    assert undocumented["type"] == "integer"
    assert undocumented["default"] == 42
    assert undocumented["description"] == "Parameter 'undocumented_param'"
    assert "undocumented_param" not in json_schema.get("required", [])


def test_create_response_model_from_callable():
def sample_func() -> dict:
"""Sample function with return description.

Returns:
A dictionary result.
"""
return {"result": "test"}

model = create_response_model_from_callable(sample_func, "Test")
docstring = docstring_parser.parse(inspect.getdoc(sample_func) or "")
model = create_response_model_from_callable(sample_func, "Test", docstring)
schema = model.model_json_schema()

assert model.__name__ == "TestResponse"
assert model.model_fields["result"].annotation == dict
assert model.model_fields["result"].is_required
assert schema["properties"]["result"]["type"] == "object"
assert "default" not in schema["properties"]["result"]
assert schema["properties"]["result"]["description"] == "A dictionary result."
assert "result" in schema["required"]


def test_create_response_model_no_docstring():
    """Check the response model built from a callable that has no docstring.

    The ``result`` field is expected to be required, typed from the return
    annotation, and to have no description in the JSON schema.
    """

    # Deliberately left without a docstring: that is the case under test.
    def sample_func_no_doc() -> int:
        return 1

    parsed_doc = docstring_parser.parse(inspect.getdoc(sample_func_no_doc) or "")
    response_model = create_response_model_from_callable(sample_func_no_doc, "NoDoc", parsed_doc)
    json_schema = response_model.model_json_schema()

    assert response_model.__name__ == "NoDocResponse"

    result_props = json_schema["properties"]["result"]
    assert result_props["type"] == "integer"
    assert result_props.get("description") is None
    assert "result" in json_schema["required"]


def test_create_pipeline_wrapper_instance_success():
Expand Down Expand Up @@ -220,11 +307,13 @@ def test_deploy_pipeline_files_skip_mcp(mocker):
mock_app = mocker.Mock()
mock_app.routes = []

# This pipeline wrapper has skip_mcp class attribute set to True
# This pipeline wrapper has skip_mcp class attribute set to True
test_file_path = Path("tests/test_files/files/chat_with_website_mcp_skip/pipeline_wrapper.py")
files = {"pipeline_wrapper.py": test_file_path.read_text()}

result = deploy_pipeline_files(app=mock_app, pipeline_name="chat_with_website_mcp_skip", files=files, save_files=False)
result = deploy_pipeline_files(
app=mock_app, pipeline_name="chat_with_website_mcp_skip", files=files, save_files=False
)
assert result == {"name": "chat_with_website_mcp_skip"}

assert registry.get_metadata("chat_with_website_mcp_skip").get("skip_mcp") is True
4 changes: 2 additions & 2 deletions tests/test_mcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,8 @@ async def test_list_pipelines_as_tools(deploy_chat_with_website_mcp):
assert tools[0].description == "Ask a question about one or more websites using a Haystack pipeline."
assert tools[0].inputSchema == {
'properties': {
'urls': {'items': {'type': 'string'}, 'title': 'Urls', 'type': 'array'},
'question': {'title': 'Question', 'type': 'string'},
'urls': {'items': {'type': 'string'}, 'title': 'Urls', 'type': 'array', 'description': "Parameter 'urls'"},
'question': {'title': 'Question', 'type': 'string', 'description': "Parameter 'question'"},
},
'required': ['urls', 'question'],
'title': 'chat_with_websiteRunRequest',
Expand Down