Initial commit: GSPro Remote MVP - Phase 1 complete
This commit is contained in:
commit
74ca4b38eb
50 changed files with 12818 additions and 0 deletions
345
backend/app/api/vision.py
Normal file
345
backend/app/api/vision.py
Normal file
|
|
@ -0,0 +1,345 @@
|
|||
"""
|
||||
Vision API for screen capture and streaming.
|
||||
Currently implements WebSocket streaming for the map panel.
|
||||
OCR and advanced vision features are gated behind configuration flags for V2.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
from typing import Optional, Dict, Any
|
||||
from fastapi import APIRouter, WebSocket, WebSocketDisconnect, HTTPException, Depends
|
||||
from fastapi.responses import JSONResponse
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from ..core.config import get_config
|
||||
from ..core.screen import get_screen_capture, capture_region
|
||||
|
||||
# Module-level logger; its name is this module's import path (__name__).
logger = logging.getLogger(__name__)
|
||||
# Router that every endpoint below registers on; presumably mounted on the
# application elsewhere - the mount point is not visible in this file.
router = APIRouter()
|
||||
|
||||
|
||||
class StreamConfig(BaseModel):
    """Client-supplied settings for a single video stream.

    The bounds declared on each field are checked when the model is built.
    """

    # Pacing and encoding.
    fps: int = Field(default=30, ge=1, le=60, description="Frames per second")
    quality: int = Field(default=85, ge=1, le=100, description="JPEG quality")
    resolution: str = Field(default="720p", description="Stream resolution preset")

    # Rectangle of the screen to capture.
    region_x: int = Field(default=0, ge=0, description="Region X coordinate")
    region_y: int = Field(default=0, ge=0, description="Region Y coordinate")
    region_width: int = Field(default=640, ge=1, description="Region width")
    region_height: int = Field(default=480, ge=1, description="Region height")
|
||||
|
||||
|
||||
class RegionDefinition(BaseModel):
    """A named rectangle of the screen that can be captured.

    Every field except ``description`` is required.
    """

    name: str = Field(default=..., description="Region name")

    # Top-left corner; negative values are rejected.
    x: int = Field(default=..., ge=0, description="X coordinate")
    y: int = Field(default=..., ge=0, description="Y coordinate")

    # Size; must be strictly positive.
    width: int = Field(default=..., gt=0, description="Width")
    height: int = Field(default=..., gt=0, description="Height")

    description: Optional[str] = Field(default=None, description="Region description")
|
||||
|
||||
|
||||
class CaptureRequest(BaseModel):
    """Body of a one-shot capture request: which rectangle, and how to encode it."""

    # Rectangle to grab.
    x: int = Field(default=0, ge=0, description="X coordinate")
    y: int = Field(default=0, ge=0, description="Y coordinate")
    width: int = Field(default=640, gt=0, description="Width")
    height: int = Field(default=480, gt=0, description="Height")

    # Encoding of the result.
    format: str = Field(default="base64", description="Output format (base64 or raw)")
    quality: int = Field(default=85, ge=1, le=100, description="JPEG quality")
|
||||
|
||||
|
||||
class StreamManager:
    """Registry of live WebSocket streams and their settings, keyed by client id."""

    def __init__(self):
        # Two maps sharing the same keys; add_stream/remove_stream update both.
        self.active_streams: Dict[str, WebSocket] = {}
        self.stream_configs: Dict[str, StreamConfig] = {}

    async def add_stream(self, client_id: str, websocket: WebSocket, config: StreamConfig):
        """Accept ``websocket`` and register it together with ``config``.

        The accept handshake is performed here, so the socket passed in must
        not have been accepted already.
        """
        await websocket.accept()
        self.stream_configs[client_id] = config
        self.active_streams[client_id] = websocket
        logger.info(f"Stream started for client {client_id}")

    async def remove_stream(self, client_id: str):
        """Forget ``client_id``; does nothing when the id is not registered."""
        if client_id not in self.active_streams:
            return

        self.active_streams.pop(client_id)
        self.stream_configs.pop(client_id)
        logger.info(f"Stream stopped for client {client_id}")

    async def stream_frame(self, client_id: str, frame_data: str):
        """Send ``frame_data`` as a text message to ``client_id``.

        Unknown ids are ignored. If sending fails for any reason, the failure
        is logged and the client is dropped from the registry.
        """
        if client_id not in self.active_streams:
            return

        target = self.active_streams[client_id]
        try:
            await target.send_text(frame_data)
        except Exception as e:
            logger.error(f"Failed to send frame to client {client_id}: {e}")
            await self.remove_stream(client_id)
|
||||
|
||||
|
||||
# Global stream manager
|
||||
# Created once at import time and shared by the endpoints in this module
# (the status endpoint reads its active_streams to count live streams).
stream_manager = StreamManager()
|
||||
|
||||
|
||||
async def _receive_stream_config(websocket: WebSocket) -> Optional[StreamConfig]:
    """Read the first client message and turn it into a StreamConfig.

    Returns None after telling the client what was wrong and closing the
    socket when the message is not valid JSON, is not a "start" action, or
    carries a config that fails validation.
    """
    raw = await websocket.receive_text()

    try:
        message = json.loads(raw)
    except json.JSONDecodeError:
        await websocket.send_json({"error": "First message must be valid JSON"})
        await websocket.close()
        return None

    if not isinstance(message, dict) or message.get("action") != "start":
        await websocket.send_json({"error": "First message must be start action"})
        await websocket.close()
        return None

    try:
        # A missing or null "config" means "use the defaults".
        return StreamConfig(**(message.get("config") or {}))
    except (TypeError, ValueError) as e:
        # TypeError: "config" is not a mapping. ValueError: pydantic's
        # ValidationError derives from it.
        await websocket.send_json({"error": f"Invalid stream config: {e}"})
        await websocket.close()
        return None


def _capture_frame(capture, stream_config: StreamConfig) -> str:
    """Grab the configured region, resize it if the preset differs, and return it base64-encoded."""
    image = capture.capture_region(
        stream_config.region_x,
        stream_config.region_y,
        stream_config.region_width,
        stream_config.region_height,
    )

    if stream_config.resolution != "native":
        target_width, target_height = capture.get_resolution_preset(stream_config.resolution)
        # Only resize if different from capture size
        if (target_width, target_height) != (stream_config.region_width, stream_config.region_height):
            image = capture.resize_image(image, width=target_width)

    return capture.image_to_base64(image, quality=stream_config.quality)


@router.websocket("/ws/stream")
async def stream_video(websocket: WebSocket):
    """
    WebSocket endpoint for streaming video of a screen region.

    The client should send a JSON message with stream configuration:
    {
        "action": "start",
        "config": {
            "fps": 30,
            "quality": 85,
            "resolution": "720p",
            "region_x": 0,
            "region_y": 0,
            "region_width": 640,
            "region_height": 480
        }
    }

    The server replies with {"type": "config", "data": {...}} holding the
    effective settings (fps capped by, and quality replaced with, the server's
    capture config), then sends {"type": "frame", "data": <base64>,
    "timestamp": <event-loop clock>} messages until the client disconnects or
    capturing fails, in which case {"type": "error", "message": ...} is sent.
    A bad first message is answered with {"error": ...} and the socket is closed.
    """
    config = get_config()
    client_id = str(id(websocket))

    try:
        # Accept the WebSocket connection
        await websocket.accept()
        logger.info(f"WebSocket connected: client {client_id}")

        # Wait for initial configuration message
        stream_config = await _receive_stream_config(websocket)
        if stream_config is None:
            return

        # Override with server config if needed. fps is clamped to at least 1
        # so the frame interval below can never divide by zero.
        stream_config.fps = max(1, min(stream_config.fps, config.capture.fps))
        stream_config.quality = config.capture.quality

        # Register directly instead of via stream_manager.add_stream(): the
        # socket was already accepted above, and a second accept() on a
        # connected socket raises RuntimeError, which aborted every stream.
        stream_manager.active_streams[client_id] = websocket
        stream_manager.stream_configs[client_id] = stream_config
        logger.info(f"Stream started for client {client_id}")

        # Send confirmation
        await websocket.send_json({"type": "config", "data": stream_config.model_dump()})

        capture = get_screen_capture()
        loop = asyncio.get_running_loop()
        frame_interval = 1.0 / stream_config.fps

        # Streaming loop
        while True:
            frame_started = loop.time()

            try:
                # NOTE: called without await, so capture and encoding run on
                # the event loop for the duration of each frame.
                base64_image = _capture_frame(capture, stream_config)
            except Exception as e:
                logger.error(f"Error in stream loop for client {client_id}: {e}")
                await websocket.send_json({"type": "error", "message": str(e)})
                break

            await websocket.send_text(
                json.dumps({"type": "frame", "data": base64_image, "timestamp": loop.time()})
            )

            # Sleep only for what is left of this frame's time slot so the
            # configured fps is honoured; sleep(0) still yields to other tasks.
            elapsed = loop.time() - frame_started
            await asyncio.sleep(max(0.0, frame_interval - elapsed))

    except WebSocketDisconnect:
        logger.info(f"Client {client_id} disconnected")
    except Exception as e:
        logger.error(f"WebSocket error for client {client_id}: {e}")
    finally:
        # No-op when the stream was never registered.
        await stream_manager.remove_stream(client_id)
        logger.info(f"WebSocket closed: client {client_id}")
|
||||
|
||||
|
||||
@router.post("/capture", response_model=Dict[str, Any])
async def capture_screen_region(request: CaptureRequest):
    """
    Capture a single frame from a screen region.

    This is useful for testing or getting a single snapshot.

    Returns a dict with "success", "format", "data" (the base64 image) and the
    requested "width"/"height". Any format other than "base64" yields
    {"success": False, "error": ...} without capturing anything.

    Raises:
        HTTPException: status 500 when capturing or encoding fails.
    """
    # Check the format first so an unsupported request does not pay for a
    # screen capture whose result would be thrown away.
    if request.format != "base64":
        return {"success": False, "error": f"Unsupported format: {request.format}"}

    try:
        capture = get_screen_capture()
        image = capture.capture_region(request.x, request.y, request.width, request.height)
        base64_image = capture.image_to_base64(image, quality=request.quality)
    except Exception as e:
        logger.exception(f"Failed to capture region: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e

    return {
        "success": True,
        "format": "base64",
        "data": base64_image,
        "width": request.width,
        "height": request.height,
    }
|
||||
|
||||
|
||||
@router.get("/regions")
async def get_regions():
    """
    List the predefined capture regions.

    The "map" entry mirrors the capture rectangle stored in the configuration;
    the remaining entries are fixed rectangles for other GSPro UI elements.
    """
    capture_cfg = get_config().capture

    regions = {
        "map": dict(
            name="Map Panel",
            x=capture_cfg.region_x,
            y=capture_cfg.region_y,
            width=capture_cfg.region_width,
            height=capture_cfg.region_height,
            description="GSPro mini-map or expanded map view",
        ),
        "club": dict(
            name="Club Indicator",
            x=50,
            y=200,
            width=200,
            height=100,
            description="Current club selection display",
        ),
        "shot_info": dict(
            name="Shot Information",
            x=50,
            y=50,
            width=300,
            height=150,
            description="Shot distance and trajectory information",
        ),
        "scorecard": dict(
            name="Scorecard",
            x=400,
            y=100,
            width=800,
            height=600,
            description="Scorecard overlay when visible",
        ),
    }

    return {"regions": regions, "total": len(regions)}
|
||||
|
||||
|
||||
@router.post("/regions/{region_name}")
async def update_region(region_name: str, region: RegionDefinition):
    """
    Store a new rectangle for a named region.

    Only "map" can be updated: its rectangle is written to the capture
    configuration and saved. Any other name gets a success=False reply.
    """
    config = get_config()

    # Anything but the map region is answered with a "not yet" reply
    # (custom regions are planned for V2).
    if region_name != "map":
        return {"success": False, "message": "Custom regions not yet supported (V2 feature)"}

    capture_cfg = config.capture
    capture_cfg.region_x = region.x
    capture_cfg.region_y = region.y
    capture_cfg.region_width = region.width
    capture_cfg.region_height = region.height
    config.save()

    return {"success": True, "message": f"Region '{region_name}' updated", "region": region.model_dump()}
|
||||
|
||||
|
||||
# OCR endpoints - gated behind vision config flag
|
||||
@router.post("/ocr")
async def perform_ocr(request: CaptureRequest):
    """
    OCR on a screen region (V2 feature, not implemented yet).

    Responds with 403 while vision features are disabled in the configuration;
    otherwise returns a placeholder reply with success=False.
    """
    settings = get_config()

    if settings.vision.enabled:
        # Placeholder until the V2 OCR implementation lands.
        return {
            "success": False,
            "message": "OCR features coming in V2",
            "vision_enabled": settings.vision.enabled,
        }

    raise HTTPException(status_code=403, detail="Vision features are not enabled. This is a V2 feature.")
|
||||
|
||||
|
||||
@router.get("/markers")
async def get_markers():
    """
    Visual markers for template matching (V2 feature, not implemented yet).

    Responds with 403 while vision features are disabled in the configuration;
    otherwise returns an empty marker list.
    """
    settings = get_config()

    if settings.vision.enabled:
        # Placeholder until V2 marker management lands.
        return {
            "markers": [],
            "message": "Marker features coming in V2",
            "vision_enabled": settings.vision.enabled,
        }

    raise HTTPException(status_code=403, detail="Vision features are not enabled. This is a V2 feature.")
|
||||
|
||||
|
||||
@router.get("/status")
async def get_vision_status():
    """Report feature flags, the number of registered streams, and the capture settings."""
    config = get_config()
    capture_cfg = config.capture

    # Streaming is always reported as enabled; OCR and markers follow the
    # single vision flag.
    status = {"streaming_enabled": True}
    status["ocr_enabled"] = config.vision.enabled
    status["markers_enabled"] = config.vision.enabled
    status["active_streams"] = len(stream_manager.active_streams)
    status["capture_config"] = dict(
        fps=capture_cfg.fps,
        quality=capture_cfg.quality,
        resolution=capture_cfg.resolution,
    )
    return status
|
||||
Loading…
Add table
Add a link
Reference in a new issue