""" Vision API for screen capture and streaming. Currently implements WebSocket streaming for the map panel. OCR and advanced vision features are gated behind configuration flags for V2. """ import asyncio import json import logging from typing import Optional, Dict, Any from fastapi import APIRouter, WebSocket, WebSocketDisconnect, HTTPException, Depends from fastapi.responses import JSONResponse from pydantic import BaseModel, Field from ..core.config import get_config from ..core.screen import get_screen_capture, capture_region logger = logging.getLogger(__name__) router = APIRouter() class StreamConfig(BaseModel): """Configuration for video streaming.""" fps: int = Field(30, ge=1, le=60, description="Frames per second") quality: int = Field(85, ge=1, le=100, description="JPEG quality") resolution: str = Field("720p", description="Stream resolution preset") region_x: int = Field(0, ge=0, description="Region X coordinate") region_y: int = Field(0, ge=0, description="Region Y coordinate") region_width: int = Field(640, ge=1, description="Region width") region_height: int = Field(480, ge=1, description="Region height") class RegionDefinition(BaseModel): """Definition of a screen region for capture.""" name: str = Field(..., description="Region name") x: int = Field(..., ge=0, description="X coordinate") y: int = Field(..., ge=0, description="Y coordinate") width: int = Field(..., gt=0, description="Width") height: int = Field(..., gt=0, description="Height") description: Optional[str] = Field(None, description="Region description") class CaptureRequest(BaseModel): """Request to capture a screen region.""" x: int = Field(0, ge=0, description="X coordinate") y: int = Field(0, ge=0, description="Y coordinate") width: int = Field(640, gt=0, description="Width") height: int = Field(480, gt=0, description="Height") format: str = Field("base64", description="Output format (base64 or raw)") quality: int = Field(85, ge=1, le=100, description="JPEG quality") class StreamManager: """Manages WebSocket streaming sessions.""" def __init__(self): self.active_streams: Dict[str, WebSocket] = {} self.stream_configs: Dict[str, StreamConfig] = {} async def add_stream(self, client_id: str, websocket: WebSocket, config: StreamConfig): """Add a new streaming session.""" await websocket.accept() self.active_streams[client_id] = websocket self.stream_configs[client_id] = config logger.info(f"Stream started for client {client_id}") async def remove_stream(self, client_id: str): """Remove a streaming session.""" if client_id in self.active_streams: del self.active_streams[client_id] del self.stream_configs[client_id] logger.info(f"Stream stopped for client {client_id}") async def stream_frame(self, client_id: str, frame_data: str): """Send a frame to a specific client.""" if client_id in self.active_streams: websocket = self.active_streams[client_id] try: await websocket.send_text(frame_data) except Exception as e: logger.error(f"Failed to send frame to client {client_id}: {e}") await self.remove_stream(client_id) # Global stream manager stream_manager = StreamManager() @router.websocket("/ws/stream") async def stream_video(websocket: WebSocket): """ WebSocket endpoint for streaming video of a screen region. The client should send a JSON message with stream configuration: { "action": "start", "config": { "fps": 30, "quality": 85, "resolution": "720p", "region_x": 0, "region_y": 0, "region_width": 640, "region_height": 480 } } """ config = get_config() client_id = id(websocket) try: # Accept the WebSocket connection await websocket.accept() logger.info(f"WebSocket connected: client {client_id}") # Wait for initial configuration message data = await websocket.receive_text() message = json.loads(data) if message.get("action") != "start": await websocket.send_json({"error": "First message must be start action"}) await websocket.close() return # Parse stream configuration stream_config = StreamConfig(**message.get("config", {})) # Override with server config if needed stream_config.fps = min(stream_config.fps, config.capture.fps) stream_config.quality = config.capture.quality # Add to stream manager await stream_manager.add_stream(str(client_id), websocket, stream_config) # Send confirmation await websocket.send_json({"type": "config", "data": stream_config.model_dump()}) # Get screen capture instance capture = get_screen_capture() # Calculate frame interval frame_interval = 1.0 / stream_config.fps # Streaming loop while True: try: # Capture the region image = capture.capture_region( stream_config.region_x, stream_config.region_y, stream_config.region_width, stream_config.region_height, ) # Resize if needed if stream_config.resolution != "native": target_width, target_height = capture.get_resolution_preset(stream_config.resolution) # Only resize if different from capture size if target_width != stream_config.region_width or target_height != stream_config.region_height: image = capture.resize_image(image, width=target_width) # Convert to base64 base64_image = capture.image_to_base64(image, quality=stream_config.quality) # Send frame frame_data = json.dumps( {"type": "frame", "data": base64_image, "timestamp": asyncio.get_event_loop().time()} ) await websocket.send_text(frame_data) # Wait for next frame await asyncio.sleep(frame_interval) except WebSocketDisconnect: logger.info(f"Client {client_id} disconnected") break except Exception as e: logger.error(f"Error in stream loop for client {client_id}: {e}") await websocket.send_json({"type": "error", "message": str(e)}) break except Exception as e: logger.error(f"WebSocket error for client {client_id}: {e}") finally: await stream_manager.remove_stream(str(client_id)) logger.info(f"WebSocket closed: client {client_id}") @router.post("/capture", response_model=Dict[str, Any]) async def capture_screen_region(request: CaptureRequest): """ Capture a single frame from a screen region. This is useful for testing or getting a single snapshot. """ try: capture = get_screen_capture() # Capture the region image = capture.capture_region(request.x, request.y, request.width, request.height) if request.format == "base64": # Convert to base64 base64_image = capture.image_to_base64(image, quality=request.quality) return { "success": True, "format": "base64", "data": base64_image, "width": request.width, "height": request.height, } else: return {"success": False, "error": f"Unsupported format: {request.format}"} except Exception as e: logger.error(f"Failed to capture region: {e}") raise HTTPException(status_code=500, detail=str(e)) @router.get("/regions") async def get_regions(): """ Get predefined screen regions for capture. Returns common regions like map area, club indicator, etc. """ config = get_config() # Predefined regions for GSPro UI elements regions = { "map": { "name": "Map Panel", "x": config.capture.region_x, "y": config.capture.region_y, "width": config.capture.region_width, "height": config.capture.region_height, "description": "GSPro mini-map or expanded map view", }, "club": { "name": "Club Indicator", "x": 50, "y": 200, "width": 200, "height": 100, "description": "Current club selection display", }, "shot_info": { "name": "Shot Information", "x": 50, "y": 50, "width": 300, "height": 150, "description": "Shot distance and trajectory information", }, "scorecard": { "name": "Scorecard", "x": 400, "y": 100, "width": 800, "height": 600, "description": "Scorecard overlay when visible", }, } return {"regions": regions, "total": len(regions)} @router.post("/regions/{region_name}") async def update_region(region_name: str, region: RegionDefinition): """ Update or create a screen region definition. This allows users to define custom regions for their setup. """ config = get_config() if region_name == "map": # Update the map region in config config.capture.region_x = region.x config.capture.region_y = region.y config.capture.region_width = region.width config.capture.region_height = region.height config.save() return {"success": True, "message": f"Region '{region_name}' updated", "region": region.model_dump()} else: # For now, only map region is persisted # V2 will add support for custom regions return {"success": False, "message": "Custom regions not yet supported (V2 feature)"} # OCR endpoints - gated behind vision config flag @router.post("/ocr") async def perform_ocr(request: CaptureRequest): """ Perform OCR on a screen region (V2 feature). This endpoint is only available when vision features are enabled. """ config = get_config() if not config.vision.enabled: raise HTTPException(status_code=403, detail="Vision features are not enabled. This is a V2 feature.") # OCR implementation will go here in V2 return {"success": False, "message": "OCR features coming in V2", "vision_enabled": config.vision.enabled} @router.get("/markers") async def get_markers(): """ Get visual markers for template matching (V2 feature). This endpoint is only available when vision features are enabled. """ config = get_config() if not config.vision.enabled: raise HTTPException(status_code=403, detail="Vision features are not enabled. This is a V2 feature.") # Marker management will go here in V2 return {"markers": [], "message": "Marker features coming in V2", "vision_enabled": config.vision.enabled} @router.get("/status") async def get_vision_status(): """Get the status of vision features.""" config = get_config() return { "streaming_enabled": True, "ocr_enabled": config.vision.enabled, "markers_enabled": config.vision.enabled, "active_streams": len(stream_manager.active_streams), "capture_config": { "fps": config.capture.fps, "quality": config.capture.quality, "resolution": config.capture.resolution, }, }