diff --git a/README.md b/README.md index 620ad56..3b9f3a8 100644 --- a/README.md +++ b/README.md @@ -371,7 +371,9 @@ A [MCP Tool](https://modelcontextprotocol.io/docs/concepts/tools) requires the f For each deployed pipeline, Hayhooks will: - Use the pipeline wrapper `name` as MCP Tool `name` (always present). -- Use the pipeline wrapper **`run_api` method docstring** as MCP Tool `description` (if present). +- Parse **`run_api` method docstring**: + - If you use Google-style or reStructuredText-style docstrings, use the first line as MCP Tool `description` and the rest as `parameters` (if present). + - Each parameter description will be used as the `description` of the corresponding Pydantic model field (if present). - Generate a Pydantic model from the `inputSchema` using the **`run_api` method arguments as fields**. Here's an example of a PipelineWrapper implementation for the `chat_with_website` pipeline which can be used as a MCP Tool: diff --git a/pyproject.toml b/pyproject.toml index 4033a76..23e32d6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,6 +34,7 @@ dependencies = [ "loguru", "pydantic-settings", "python-dotenv", + "docstring-parser", ] [project.optional-dependencies] diff --git a/src/hayhooks/server/utils/deploy_utils.py b/src/hayhooks/server/utils/deploy_utils.py index 22eb52a..15b9eae 100644 --- a/src/hayhooks/server/utils/deploy_utils.py +++ b/src/hayhooks/server/utils/deploy_utils.py @@ -4,15 +4,17 @@ import tempfile import traceback import sys +import docstring_parser +from docstring_parser.common import Docstring from functools import wraps from pathlib import Path from types import ModuleType -from typing import Callable, Union +from typing import Callable, Union, Any, Dict from fastapi import FastAPI, Form, HTTPException from fastapi.concurrency import run_in_threadpool from fastapi.responses import JSONResponse from fastapi.routing import APIRoute -from pydantic import create_model +from pydantic import create_model, Field from hayhooks.server.exceptions import ( PipelineAlreadyExistsError, PipelineFilesError, @@ -178,7 +180,7 @@ def load_pipeline_module(pipeline_name: str, dir_path: Union[Path, str]) -> Modu raise PipelineModuleLoadError(error_msg) from e -def create_request_model_from_callable(func: Callable, model_name: str): +def create_request_model_from_callable(func: Callable, model_name: str, docstring: Docstring): """Create a dynamic Pydantic model based on callable's signature. Args: @@ -190,14 +192,19 @@ def create_request_model_from_callable(func: Callable, model_name: str): """ params = inspect.signature(func).parameters - fields = { - name: (param.annotation, ... if param.default == param.empty else param.default) - for name, param in params.items() - } + param_docs = {p.arg_name: p.description for p in docstring.params} + + fields: Dict[str, Any] = {} + for name, param in params.items(): + default_value = ... if param.default == param.empty else param.default + description = param_docs.get(name) or f"Parameter '{name}'" + field_info = Field(default=default_value, description=description) + fields[name] = (param.annotation, field_info) + return create_model(f'{model_name}Request', **fields) -def create_response_model_from_callable(func: Callable, model_name: str): +def create_response_model_from_callable(func: Callable, model_name: str, docstring: Docstring): """Create a dynamic Pydantic model based on callable's return type. Args: @@ -209,7 +216,9 @@ def create_response_model_from_callable(func: Callable, model_name: str): """ return_type = inspect.signature(func).return_annotation - return create_model(f'{model_name}Response', result=(return_type, ...)) + return_description = docstring.returns.description if docstring.returns else None + + return create_model(f'{model_name}Response', result=(return_type, Field(..., description=return_description))) def handle_pipeline_exceptions(): @@ -270,8 +279,9 @@ async def run_endpoint_without_files(run_req: request_model) -> response_model: def add_pipeline_api_route(app: FastAPI, pipeline_name: str, pipeline_wrapper: BasePipelineWrapper) -> None: clog = log.bind(pipeline_name=pipeline_name) - RunRequest = create_request_model_from_callable(pipeline_wrapper.run_api, f'{pipeline_name}Run') - RunResponse = create_response_model_from_callable(pipeline_wrapper.run_api, f'{pipeline_name}Run') + docstring = docstring_parser.parse(inspect.getdoc(pipeline_wrapper.run_api) or "") + RunRequest = create_request_model_from_callable(pipeline_wrapper.run_api, f'{pipeline_name}Run', docstring) + RunResponse = create_response_model_from_callable(pipeline_wrapper.run_api, f'{pipeline_name}Run', docstring) run_api_params = inspect.signature(pipeline_wrapper.run_api).parameters requires_files = "files" in run_api_params @@ -289,8 +299,6 @@ def add_pipeline_api_route(app: FastAPI, pipeline_name: str, pipeline_wrapper: B if isinstance(route, APIRoute) and route.path == f"/{pipeline_name}/run": app.routes.remove(route) - docstring = inspect.getdoc(pipeline_wrapper.run_api) - app.add_api_route( path=f"/{pipeline_name}/run", endpoint=run_endpoint, @@ -298,7 +306,7 @@ def add_pipeline_api_route(app: FastAPI, pipeline_name: str, pipeline_wrapper: B name=f"{pipeline_name}_run", response_model=RunResponse, tags=["pipelines"], - description=docstring or None, + description=docstring.short_description or None, ) registry.update_metadata( @@ -383,9 +391,10 @@ def add_pipeline_to_registry( clog.debug("Running setup()") pipeline_wrapper.setup() + docstring = docstring_parser.parse(inspect.getdoc(pipeline_wrapper.run_api) or "") metadata = { - "description": inspect.getdoc(pipeline_wrapper.run_api) or "", - "request_model": create_request_model_from_callable(pipeline_wrapper.run_api, f'{pipeline_name}Run'), + "description": docstring.short_description or "", + "request_model": create_request_model_from_callable(pipeline_wrapper.run_api, f'{pipeline_name}Run', docstring), "skip_mcp": pipeline_wrapper.skip_mcp, } diff --git a/tests/test_deploy_utils.py b/tests/test_deploy_utils.py index fc54f87..22b41fb 100644 --- a/tests/test_deploy_utils.py +++ b/tests/test_deploy_utils.py @@ -1,7 +1,9 @@ -from fastapi.routing import APIRoute -from hayhooks.server.pipelines import registry import pytest import shutil +import docstring_parser +import inspect +from fastapi.routing import APIRoute +from hayhooks.server.pipelines import registry from haystack import Pipeline from pathlib import Path from typing import Callable @@ -106,28 +108,113 @@ def test_save_pipeline_files_raises_error(tmp_path): def test_create_request_model_from_callable(): def sample_func(name: str, age: int = 25, optional: str = ""): + """Sample function with docstring. + + Args: + name: The name of the person. + age: The age of the person. + optional: An optional string. + """ pass - model = create_request_model_from_callable(sample_func, "Test") + docstring = docstring_parser.parse(inspect.getdoc(sample_func) or "") + model = create_request_model_from_callable(sample_func, "Test", docstring) + schema = model.model_json_schema() assert model.__name__ == "TestRequest" - assert model.model_fields["name"].annotation == str - assert model.model_fields["name"].is_required - assert model.model_fields["age"].annotation == int - assert model.model_fields["age"].default == 25 - assert model.model_fields["optional"].annotation == str - assert model.model_fields["optional"].default == "" + assert schema["properties"]["name"]["type"] == "string" + assert "default" not in schema["properties"]["name"] + assert schema["properties"]["name"]["description"] == "The name of the person." + assert "name" in schema["required"] + + assert schema["properties"]["age"]["type"] == "integer" + assert schema["properties"]["age"]["default"] == 25 + assert schema["properties"]["age"]["description"] == "The age of the person." + assert "age" not in schema.get("required", []) + + assert schema["properties"]["optional"]["type"] == "string" + assert schema["properties"]["optional"]["default"] == "" + assert schema["properties"]["optional"]["description"] == "An optional string." + assert "optional" not in schema.get("required", []) + + +def test_create_request_model_no_docstring(): + def sample_func_no_doc(name: str, age: int = 30): + pass + + docstring = docstring_parser.parse(inspect.getdoc(sample_func_no_doc) or "") + model = create_request_model_from_callable(sample_func_no_doc, "NoDoc", docstring) + schema = model.model_json_schema() + + assert model.__name__ == "NoDocRequest" + assert schema["properties"]["name"]["type"] == "string" + assert schema["properties"]["name"]["description"] == "Parameter 'name'" + assert "name" in schema["required"] + + assert schema["properties"]["age"]["type"] == "integer" + assert schema["properties"]["age"]["default"] == 30 + assert schema["properties"]["age"]["description"] == "Parameter 'age'" + assert "age" not in schema.get("required", []) + + +def test_create_request_model_partial_docstring(): + def sample_func_partial_doc(documented_param: str, undocumented_param: int = 42): + """Sample function with partial docstring. + + Args: + documented_param: This parameter is documented. + """ + pass + + docstring = docstring_parser.parse(inspect.getdoc(sample_func_partial_doc) or "") + model = create_request_model_from_callable(sample_func_partial_doc, "PartialDoc", docstring) + schema = model.model_json_schema() + + assert model.__name__ == "PartialDocRequest" + + assert schema["properties"]["documented_param"]["type"] == "string" + assert "default" not in schema["properties"]["documented_param"] + assert schema["properties"]["documented_param"]["description"] == "This parameter is documented." + assert "documented_param" in schema["required"] + + assert schema["properties"]["undocumented_param"]["type"] == "integer" + assert schema["properties"]["undocumented_param"]["default"] == 42 + assert schema["properties"]["undocumented_param"]["description"] == "Parameter 'undocumented_param'" + assert "undocumented_param" not in schema.get("required", []) def test_create_response_model_from_callable(): def sample_func() -> dict: + """Sample function with return description. + + Returns: + A dictionary result. + """ return {"result": "test"} - model = create_response_model_from_callable(sample_func, "Test") + docstring = docstring_parser.parse(inspect.getdoc(sample_func) or "") + model = create_response_model_from_callable(sample_func, "Test", docstring) + schema = model.model_json_schema() assert model.__name__ == "TestResponse" - assert model.model_fields["result"].annotation == dict - assert model.model_fields["result"].is_required + assert schema["properties"]["result"]["type"] == "object" + assert "default" not in schema["properties"]["result"] + assert schema["properties"]["result"]["description"] == "A dictionary result." + assert "result" in schema["required"] + + +def test_create_response_model_no_docstring(): + def sample_func_no_doc() -> int: + return 1 + + docstring = docstring_parser.parse(inspect.getdoc(sample_func_no_doc) or "") + model = create_response_model_from_callable(sample_func_no_doc, "NoDoc", docstring) + schema = model.model_json_schema() + + assert model.__name__ == "NoDocResponse" + assert schema["properties"]["result"]["type"] == "integer" + assert schema["properties"]["result"].get("description") is None + assert "result" in schema["required"] def test_create_pipeline_wrapper_instance_success(): @@ -220,11 +307,13 @@ def test_deploy_pipeline_files_skip_mcp(mocker): mock_app = mocker.Mock() mock_app.routes = [] - # This pipeline wrapper has skip_mcp class attribute set to True + # This pipeline wrapper has skip_mcp class attribute set to True test_file_path = Path("tests/test_files/files/chat_with_website_mcp_skip/pipeline_wrapper.py") files = {"pipeline_wrapper.py": test_file_path.read_text()} - result = deploy_pipeline_files(app=mock_app, pipeline_name="chat_with_website_mcp_skip", files=files, save_files=False) + result = deploy_pipeline_files( + app=mock_app, pipeline_name="chat_with_website_mcp_skip", files=files, save_files=False + ) assert result == {"name": "chat_with_website_mcp_skip"} assert registry.get_metadata("chat_with_website_mcp_skip").get("skip_mcp") is True diff --git a/tests/test_mcp.py b/tests/test_mcp.py index 77f5382..4276384 100644 --- a/tests/test_mcp.py +++ b/tests/test_mcp.py @@ -57,8 +57,8 @@ async def test_list_pipelines_as_tools(deploy_chat_with_website_mcp): assert tools[0].description == "Ask a question about one or more websites using a Haystack pipeline." assert tools[0].inputSchema == { 'properties': { - 'urls': {'items': {'type': 'string'}, 'title': 'Urls', 'type': 'array'}, - 'question': {'title': 'Question', 'type': 'string'}, + 'urls': {'items': {'type': 'string'}, 'title': 'Urls', 'type': 'array', 'description': "Parameter 'urls'"}, + 'question': {'title': 'Question', 'type': 'string', 'description': "Parameter 'question'"}, }, 'required': ['urls', 'question'], 'title': 'chat_with_websiteRunRequest',