Spaces:
Running
Running
| import httpx | |
| from .schemas import ( | |
| EndpointInfo, | |
| Parameter, | |
| RequestBody, | |
| Response, | |
| ) | |
# Display name for each JSON Schema primitive type. Every name maps to itself;
# callers fall back to the raw type string for anything not listed here.
JSON_TYPE_MAP: dict[str, str] = {
    type_name: type_name
    for type_name in (
        "string",
        "integer",
        "number",
        "boolean",
        "array",
        "object",
    )
}
| def _resolve_ref(spec: dict, ref: str) -> dict: | |
| """Resolve a $ref pointer like '#/components/schemas/MyModel'.""" | |
| parts = ref.lstrip("#/").split("/") | |
| node = spec | |
| for part in parts: | |
| node = node[part] | |
| return node | |
| def _get_type(schema: dict, spec: dict) -> str: | |
| if "$ref" in schema: | |
| schema = _resolve_ref(spec, schema["$ref"]) | |
| if "type" in schema: | |
| return JSON_TYPE_MAP.get(schema["type"], schema["type"]) | |
| if "anyOf" in schema: | |
| types = [_get_type(s, spec) for s in schema["anyOf"] if s.get("type") != "null"] | |
| return types[0] if len(types) == 1 else " | ".join(types) | |
| return "unknown" | |
| def _extract_fields(schema: dict, spec: dict) -> dict[str, str]: | |
| """Extract field_name -> type from an object schema.""" | |
| if "$ref" in schema: | |
| schema = _resolve_ref(spec, schema["$ref"]) | |
| properties = schema.get("properties", {}) | |
| return {name: _get_type(prop, spec) for name, prop in properties.items()} | |
| def _parse_parameters(params: list[dict], spec: dict) -> list[Parameter]: | |
| result = [] | |
| for p in params: | |
| if "$ref" in p: | |
| p = _resolve_ref(spec, p["$ref"]) | |
| schema = p.get("schema", {}) | |
| result.append( | |
| Parameter( | |
| name=p["name"], | |
| location=p["in"], | |
| type=_get_type(schema, spec), | |
| required=p.get("required", False), | |
| description=p.get("description"), | |
| ) | |
| ) | |
| return result | |
| def _parse_request_body(body: dict | None, spec: dict) -> RequestBody | None: | |
| if not body: | |
| return None | |
| if "$ref" in body: | |
| body = _resolve_ref(spec, body["$ref"]) | |
| content = body.get("content", {}) | |
| for content_type, media in content.items(): | |
| schema = media.get("schema", {}) | |
| fields = _extract_fields(schema, spec) | |
| return RequestBody(content_type=content_type, fields=fields) | |
| return None | |
| def _parse_responses(responses: dict, spec: dict) -> list[Response]: | |
| result = [] | |
| for status_code, resp in responses.items(): | |
| if "$ref" in resp: | |
| resp = _resolve_ref(spec, resp["$ref"]) | |
| content = resp.get("content", {}) | |
| if content: | |
| for content_type, media in content.items(): | |
| schema = media.get("schema", {}) | |
| fields = _extract_fields(schema, spec) | |
| result.append( | |
| Response( | |
| status_code=str(status_code), | |
| description=resp.get("description"), | |
| content_type=content_type, | |
| fields=fields, | |
| ) | |
| ) | |
| break | |
| else: | |
| result.append( | |
| Response( | |
| status_code=str(status_code), | |
| description=resp.get("description"), | |
| fields={}, | |
| ) | |
| ) | |
| return result | |
def parse_endpoint(spec: dict, path: str, method: str, operation: dict) -> EndpointInfo:
    """Build an ``EndpointInfo`` from a single operation object.

    The method is upper-cased; ``summary``, ``description`` and
    ``operationId`` are copied when present (``None`` otherwise); parameters,
    request body and responses are parsed with *spec* available for ``$ref``
    resolution.
    """
    http_method = method.upper()
    parameters = _parse_parameters(operation.get("parameters", []), spec)
    request_body = _parse_request_body(operation.get("requestBody"), spec)
    responses = _parse_responses(operation.get("responses", {}), spec)

    return EndpointInfo(
        path=path,
        method=http_method,
        summary=operation.get("summary"),
        description=operation.get("description"),
        operation_id=operation.get("operationId"),
        parameters=parameters,
        request_body=request_body,
        responses=responses,
    )
| def _normalize_path(p: str) -> str: | |
| """Strip trailing slashes for consistent comparison, but keep root '/'.""" | |
| return p.rstrip("/") or "/" | |
async def fetch_and_parse(spec_url: str, path: str | None = None, method: str | None = None) -> list[EndpointInfo]:
    """Download an OpenAPI document and parse its operations.

    Args:
        spec_url: URL of the document; it is fetched with redirects followed
            and decoded as JSON.
        path: When given, only the path equal to this one (ignoring trailing
            slashes) is parsed.
        method: When given, only operations using this HTTP method
            (case-insensitive) are parsed.

    Returns:
        One ``EndpointInfo`` per matching operation, in the iteration order
        of the parsed document.

    Raises:
        httpx.HTTPStatusError: If the server replies with an error status.
    """
    # Keys of a Path Item Object that hold operations. Everything else
    # ("parameters", "summary", "servers", "$ref", "x-*" extensions, ...) is
    # metadata and must not be handed to parse_endpoint as an operation.
    http_methods = frozenset(
        {"get", "put", "post", "delete", "options", "head", "patch", "trace"}
    )

    async with httpx.AsyncClient() as client:
        resp = await client.get(str(spec_url), follow_redirects=True)
        resp.raise_for_status()
        spec = resp.json()

    normalized_path = _normalize_path(path) if path else None
    wanted_method = method.upper() if method else None

    endpoints: list[EndpointInfo] = []
    for ep_path, path_item in spec.get("paths", {}).items():
        if normalized_path and _normalize_path(ep_path) != normalized_path:
            continue
        for ep_method, operation in path_item.items():
            if ep_method.lower() not in http_methods:
                continue
            if wanted_method and ep_method.upper() != wanted_method:
                continue
            endpoints.append(parse_endpoint(spec, ep_path, ep_method, operation))
    return endpoints