gsproremote/backend/app/api/vision.py

345 lines
12 KiB
Python

"""
Vision API for screen capture and streaming.
Currently implements WebSocket streaming for the map panel.
OCR and advanced vision features are gated behind configuration flags for V2.
"""
import asyncio
import json
import logging
from typing import Optional, Dict, Any

from fastapi import APIRouter, WebSocket, WebSocketDisconnect, HTTPException, Depends
from fastapi.responses import JSONResponse
from fastapi.websockets import WebSocketState
from pydantic import BaseModel, Field

from ..core.config import get_config
from ..core.screen import get_screen_capture, capture_region

# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)
# Router on which every endpoint defined in this module registers itself.
router = APIRouter()
class StreamConfig(BaseModel):
    """Settings a client supplies when it asks for a video stream to start."""

    # Capture rate and JPEG encoder quality, both range-checked on input.
    fps: int = Field(default=30, ge=1, le=60, description="Frames per second")
    quality: int = Field(default=85, ge=1, le=100, description="JPEG quality")

    # Name of an output-size preset; the streaming endpoint skips resizing
    # when this is "native".
    resolution: str = Field(default="720p", description="Stream resolution preset")

    # Rectangle of the screen to capture: non-negative origin, positive size.
    region_x: int = Field(default=0, ge=0, description="Region X coordinate")
    region_y: int = Field(default=0, ge=0, description="Region Y coordinate")
    region_width: int = Field(default=640, ge=1, description="Region width")
    region_height: int = Field(default=480, ge=1, description="Region height")
class RegionDefinition(BaseModel):
    """A named rectangle of the screen that can be captured."""

    # Every field except the description is required.
    name: str = Field(default=..., description="Region name")

    # Rectangle geometry: non-negative origin, strictly positive size.
    x: int = Field(default=..., ge=0, description="X coordinate")
    y: int = Field(default=..., ge=0, description="Y coordinate")
    width: int = Field(default=..., gt=0, description="Width")
    height: int = Field(default=..., gt=0, description="Height")

    # Free-form text; may be omitted.
    description: Optional[str] = Field(default=None, description="Region description")
class CaptureRequest(BaseModel):
    """Body of a request for a single capture of a screen rectangle."""

    # Rectangle to capture: non-negative origin, strictly positive size.
    x: int = Field(default=0, ge=0, description="X coordinate")
    y: int = Field(default=0, ge=0, description="Y coordinate")
    width: int = Field(default=640, gt=0, description="Width")
    height: int = Field(default=480, gt=0, description="Height")

    # How the image should be returned, and the JPEG quality to encode with.
    format: str = Field(default="base64", description="Output format (base64 or raw)")
    quality: int = Field(default=85, ge=1, le=100, description="JPEG quality")
class StreamManager:
    """Registry of the WebSocket streaming sessions that are currently open.

    Sessions are keyed by a caller-chosen ``client_id`` string. For each one
    the manager keeps the socket and the ``StreamConfig`` it was started with.
    """

    def __init__(self):
        self.active_streams: Dict[str, WebSocket] = {}
        self.stream_configs: Dict[str, StreamConfig] = {}

    async def add_stream(self, client_id: str, websocket: WebSocket, config: StreamConfig):
        """Register a streaming session, accepting the socket if still pending.

        Args:
            client_id: Key under which the session is stored.
            websocket: The client's connection.
            config: Stream settings to remember for this session.
        """
        # The /ws/stream endpoint accepts the socket itself so that it can
        # read the client's start message before registering. Sending a second
        # accept makes Starlette raise RuntimeError, so only accept sockets
        # whose handshake has not been completed yet.
        if websocket.application_state == WebSocketState.CONNECTING:
            await websocket.accept()
        self.active_streams[client_id] = websocket
        self.stream_configs[client_id] = config
        logger.info(f"Stream started for client {client_id}")

    async def remove_stream(self, client_id: str):
        """Forget a streaming session; unknown ids are silently ignored."""
        if client_id not in self.active_streams:
            return
        del self.active_streams[client_id]
        # Tolerate a missing config entry rather than failing during cleanup.
        self.stream_configs.pop(client_id, None)
        logger.info(f"Stream stopped for client {client_id}")

    async def stream_frame(self, client_id: str, frame_data: str):
        """Send one frame to a registered client.

        Does nothing when ``client_id`` is not registered. If sending fails
        the error is logged and the session is removed.
        """
        websocket = self.active_streams.get(client_id)
        if websocket is None:
            return
        try:
            await websocket.send_text(frame_data)
        except Exception as e:
            logger.error(f"Failed to send frame to client {client_id}: {e}")
            await self.remove_stream(client_id)
# Global stream manager: a single module-level instance, used by the
# streaming endpoint to register sessions and by the status endpoint to
# count them.
stream_manager = StreamManager()
@router.websocket("/ws/stream")
async def stream_video(websocket: WebSocket):
    """
    WebSocket endpoint for streaming video of a screen region.

    The client should send a JSON message with stream configuration:
    {
        "action": "start",
        "config": {
            "fps": 30,
            "quality": 85,
            "resolution": "720p",
            "region_x": 0,
            "region_y": 0,
            "region_width": 640,
            "region_height": 480
        }
    }

    Server messages:
      * {"error": ...} followed by a close when the first message is not a
        valid start request or its config fails validation.
      * {"type": "config", "data": {...}} with the effective settings: fps is
        capped by the server's capture fps and quality is always the server's.
      * {"type": "frame", "data": <base64 image>, "timestamp": <loop clock>}
        for every captured frame.
      * {"type": "error", "message": ...} when capturing or sending fails,
        after which streaming stops.
    """
    config = get_config()
    client_id = id(websocket)
    stream_key = str(client_id)
    try:
        # Accept the WebSocket connection
        await websocket.accept()
        logger.info(f"WebSocket connected: client {client_id}")

        # Wait for the initial configuration message. Anything that is not a
        # JSON object with action == "start" is rejected with an explicit
        # error instead of surfacing only as a server-side log entry.
        raw_message = await websocket.receive_text()
        try:
            message = json.loads(raw_message)
        except ValueError:
            message = None
        if not isinstance(message, dict) or message.get("action") != "start":
            await websocket.send_json({"error": "First message must be start action"})
            await websocket.close()
            return

        # Parse stream configuration. Pydantic validation errors derive from
        # ValueError; TypeError covers a "config" value that is not a mapping.
        try:
            stream_config = StreamConfig(**(message.get("config") or {}))
        except (ValueError, TypeError) as e:
            logger.warning(f"Invalid stream config from client {client_id}: {e}")
            await websocket.send_json({"error": f"Invalid stream configuration: {e}"})
            await websocket.close()
            return

        # Override with server config. Assignment bypasses the field
        # validators, so clamp fps to at least 1 to keep the frame interval
        # below well defined even if the server value is zero or negative.
        stream_config.fps = max(1, min(stream_config.fps, config.capture.fps))
        stream_config.quality = config.capture.quality

        # Add to stream manager
        await stream_manager.add_stream(stream_key, websocket, stream_config)

        # Send confirmation
        await websocket.send_json({"type": "config", "data": stream_config.model_dump()})

        # Get screen capture instance
        capture = get_screen_capture()
        loop = asyncio.get_running_loop()
        frame_interval = 1.0 / stream_config.fps

        # Streaming loop
        while True:
            frame_started = loop.time()
            try:
                # Capture the region
                image = capture.capture_region(
                    stream_config.region_x,
                    stream_config.region_y,
                    stream_config.region_width,
                    stream_config.region_height,
                )

                # Resize if needed
                if stream_config.resolution != "native":
                    target_width, target_height = capture.get_resolution_preset(stream_config.resolution)
                    # Only resize if different from capture size
                    if (
                        target_width != stream_config.region_width
                        or target_height != stream_config.region_height
                    ):
                        image = capture.resize_image(image, width=target_width)

                # Convert to base64 and send the frame
                base64_image = capture.image_to_base64(image, quality=stream_config.quality)
                frame_data = json.dumps(
                    {"type": "frame", "data": base64_image, "timestamp": loop.time()}
                )
                await websocket.send_text(frame_data)
            except WebSocketDisconnect:
                logger.info(f"Client {client_id} disconnected")
                break
            except Exception as e:
                logger.error(f"Error in stream loop for client {client_id}: {e}")
                # Best-effort notification: the failure may be the socket
                # itself, in which case this send fails too.
                try:
                    await websocket.send_json({"type": "error", "message": str(e)})
                except Exception as notify_error:
                    logger.debug(f"Could not notify client {client_id} of error: {notify_error}")
                break

            # Wait for next frame. Subtract the time already spent capturing,
            # encoding and sending so the delivered rate tracks the configured
            # fps; sleep(0) still yields to the event loop when running late.
            elapsed = loop.time() - frame_started
            await asyncio.sleep(max(0.0, frame_interval - elapsed))
    except WebSocketDisconnect:
        # The client went away during the handshake; not an error.
        logger.info(f"Client {client_id} disconnected")
    except Exception as e:
        logger.error(f"WebSocket error for client {client_id}: {e}")
    finally:
        await stream_manager.remove_stream(stream_key)
        logger.info(f"WebSocket closed: client {client_id}")
@router.post("/capture", response_model=Dict[str, Any])
async def capture_screen_region(request: CaptureRequest):
    """
    Capture a single frame from a screen region.

    This is useful for testing or getting a single snapshot.

    Returns:
        On success, a dict with "success": True, the "format", the encoded
        image under "data", and the requested "width" and "height".
        For any format other than "base64", a dict with "success": False and
        an "error" message.

    Raises:
        HTTPException: 500 with the underlying error text if capturing or
            encoding fails.
    """
    # Only base64 output is implemented. Reject anything else before touching
    # the screen so an unsupported request does not pay for a capture.
    if request.format != "base64":
        return {"success": False, "error": f"Unsupported format: {request.format}"}

    try:
        capture = get_screen_capture()
        image = capture.capture_region(request.x, request.y, request.width, request.height)
        base64_image = capture.image_to_base64(image, quality=request.quality)
    except Exception as e:
        logger.error(f"Failed to capture region: {e}")
        # Chain the cause so the original traceback is preserved.
        raise HTTPException(status_code=500, detail=str(e)) from e

    return {
        "success": True,
        "format": "base64",
        "data": base64_image,
        "width": request.width,
        "height": request.height,
    }
@router.get("/regions")
async def get_regions():
    """
    List the predefined screen regions available for capture.

    The "map" entry mirrors the capture rectangle stored in the server
    configuration; the other entries are fixed coordinates hard-coded here
    for further GSPro UI elements. The response also carries the number of
    regions under "total".
    """
    config = get_config()

    regions = {}

    # Map panel: follows whatever rectangle is currently configured.
    regions["map"] = dict(
        name="Map Panel",
        x=config.capture.region_x,
        y=config.capture.region_y,
        width=config.capture.region_width,
        height=config.capture.region_height,
        description="GSPro mini-map or expanded map view",
    )

    # Remaining GSPro UI elements use fixed coordinates.
    regions["club"] = dict(
        name="Club Indicator",
        x=50,
        y=200,
        width=200,
        height=100,
        description="Current club selection display",
    )
    regions["shot_info"] = dict(
        name="Shot Information",
        x=50,
        y=50,
        width=300,
        height=150,
        description="Shot distance and trajectory information",
    )
    regions["scorecard"] = dict(
        name="Scorecard",
        x=400,
        y=100,
        width=800,
        height=600,
        description="Scorecard overlay when visible",
    )

    return dict(regions=regions, total=len(regions))
@router.post("/regions/{region_name}")
async def update_region(region_name: str, region: RegionDefinition):
    """
    Store a new definition for a named screen region.

    Only the "map" region is persisted: its rectangle is written into the
    capture configuration and saved. Every other name gets a
    "success": False response.
    """
    config = get_config()

    # Custom regions are deferred to V2, so anything but "map" is refused.
    if region_name != "map":
        return {"success": False, "message": "Custom regions not yet supported (V2 feature)"}

    # Copy the rectangle into the capture settings and persist it.
    capture_settings = config.capture
    capture_settings.region_x, capture_settings.region_y = region.x, region.y
    capture_settings.region_width, capture_settings.region_height = region.width, region.height
    config.save()

    confirmation = f"Region '{region_name}' updated"
    return {"success": True, "message": confirmation, "region": region.model_dump()}
# OCR endpoints - gated behind vision config flag
@router.post("/ocr")
async def perform_ocr(request: CaptureRequest):
    """
    OCR a screen region (V2 feature; placeholder for now).

    Responds with 403 unless vision features are enabled in the
    configuration. When they are enabled it still only reports that OCR is
    not implemented yet; the request body is not used.
    """
    config = get_config()

    if config.vision.enabled:
        # OCR implementation will go here in V2
        return {
            "success": False,
            "message": "OCR features coming in V2",
            "vision_enabled": config.vision.enabled,
        }

    raise HTTPException(status_code=403, detail="Vision features are not enabled. This is a V2 feature.")
@router.get("/markers")
async def get_markers():
    """
    List visual markers for template matching (V2 feature; placeholder).

    Responds with 403 unless vision features are enabled in the
    configuration. When they are enabled the marker list is always empty.
    """
    config = get_config()

    if config.vision.enabled:
        # Marker management will go here in V2
        return {
            "markers": [],
            "message": "Marker features coming in V2",
            "vision_enabled": config.vision.enabled,
        }

    raise HTTPException(status_code=403, detail="Vision features are not enabled. This is a V2 feature.")
@router.get("/status")
async def get_vision_status():
    """Report which vision features are on, plus the current capture settings."""
    config = get_config()

    # Streaming is always reported as available; OCR and markers both follow
    # the single vision flag.
    status = {"streaming_enabled": True}
    status["ocr_enabled"] = config.vision.enabled
    status["markers_enabled"] = config.vision.enabled

    # Number of sessions currently registered with the stream manager.
    status["active_streams"] = len(stream_manager.active_streams)

    status["capture_config"] = dict(
        fps=config.capture.fps,
        quality=config.capture.quality,
        resolution=config.capture.resolution,
    )
    return status