346 lines
12 KiB
Python
346 lines
12 KiB
Python
|
|
"""
|
||
|
|
Vision API for screen capture and streaming.
|
||
|
|
Currently implements WebSocket streaming for the map panel.
|
||
|
|
OCR and advanced vision features are gated behind configuration flags for V2.
|
||
|
|
"""
|
||
|
|
|
||
|
|
import asyncio
|
||
|
|
import json
|
||
|
|
import logging
|
||
|
|
from typing import Optional, Dict, Any
|
||
|
|
from fastapi import APIRouter, WebSocket, WebSocketDisconnect, HTTPException, Depends
|
||
|
|
from fastapi.responses import JSONResponse
|
||
|
|
from pydantic import BaseModel, Field
|
||
|
|
|
||
|
|
from ..core.config import get_config
|
||
|
|
from ..core.screen import get_screen_capture, capture_region
|
||
|
|
|
||
|
|
# Logger namespaced under this module's import path.
logger = logging.getLogger(name=__name__)
# Router that the vision endpoints below are registered on.
router: APIRouter = APIRouter()
|
||
|
|
|
||
|
|
|
||
|
|
class StreamConfig(BaseModel):
    """Configuration for video streaming.

    Describes the screen region to capture (``region_*``, in pixels) plus the
    requested frame rate, JPEG quality and output resolution preset name.
    """

    # Requested frame rate, constrained to 1..60.
    fps: int = Field(default=30, ge=1, le=60, description="Frames per second")
    # JPEG encoding quality, constrained to 1..100.
    quality: int = Field(default=85, ge=1, le=100, description="JPEG quality")
    # Preset name; "native" is special-cased by the streaming endpoint.
    resolution: str = Field(default="720p", description="Stream resolution preset")
    # Top-left corner of the capture region (non-negative).
    region_x: int = Field(default=0, ge=0, description="Region X coordinate")
    region_y: int = Field(default=0, ge=0, description="Region Y coordinate")
    # Size of the capture region (at least 1 pixel each).
    region_width: int = Field(default=640, ge=1, description="Region width")
    region_height: int = Field(default=480, ge=1, description="Region height")
|
||
|
|
|
||
|
|
|
||
|
|
class RegionDefinition(BaseModel):
    """Definition of a screen region for capture.

    All geometry fields are required; ``description`` is optional.
    """

    name: str = Field(description="Region name")
    # Top-left corner, must be non-negative.
    x: int = Field(ge=0, description="X coordinate")
    y: int = Field(ge=0, description="Y coordinate")
    # Dimensions, must be strictly positive.
    width: int = Field(gt=0, description="Width")
    height: int = Field(gt=0, description="Height")
    description: Optional[str] = Field(default=None, description="Region description")
|
||
|
|
|
||
|
|
|
||
|
|
class CaptureRequest(BaseModel):
    """Request to capture a screen region.

    Geometry defaults to a 640x480 region at the origin. ``format`` selects
    the response encoding and ``quality`` the JPEG quality (1..100).
    """

    x: int = Field(default=0, ge=0, description="X coordinate")
    y: int = Field(default=0, ge=0, description="Y coordinate")
    width: int = Field(default=640, gt=0, description="Width")
    height: int = Field(default=480, gt=0, description="Height")
    # Field name intentionally matches the public request key "format".
    format: str = Field(default="base64", description="Output format (base64 or raw)")
    quality: int = Field(default=85, ge=1, le=100, description="JPEG quality")
|
||
|
|
|
||
|
|
|
||
|
|
class StreamManager:
    """Manages WebSocket streaming sessions, keyed by client id."""

    def __init__(self):
        # client_id -> WebSocket of the active session
        self.active_streams: Dict[str, WebSocket] = {}
        # client_id -> StreamConfig in effect for that session
        self.stream_configs: Dict[str, StreamConfig] = {}

    async def add_stream(self, client_id: str, websocket: WebSocket, config: StreamConfig):
        """Register a streaming session, accepting the socket only if still pending.

        Callers (e.g. the streaming endpoint) may already have accepted the
        connection in order to read a start message. Sending a second
        ``websocket.accept`` makes Starlette raise ``RuntimeError``, so accept
        is only issued while the server side is still in ``CONNECTING``.

        Args:
            client_id: Key under which the session is stored.
            websocket: The client connection.
            config: Stream settings for this session.
        """
        from fastapi.websockets import WebSocketState

        if websocket.application_state == WebSocketState.CONNECTING:
            await websocket.accept()
        self.active_streams[client_id] = websocket
        self.stream_configs[client_id] = config
        logger.info(f"Stream started for client {client_id}")

    async def remove_stream(self, client_id: str):
        """Unregister a streaming session; a no-op for unknown ``client_id``."""
        websocket = self.active_streams.pop(client_id, None)
        # Pop both maps independently so they can never get out of sync.
        self.stream_configs.pop(client_id, None)
        if websocket is not None:
            logger.info(f"Stream stopped for client {client_id}")

    async def stream_frame(self, client_id: str, frame_data: str):
        """Send a text frame to a registered client.

        Unknown ids are ignored. If the send fails, the error is logged and
        the session is unregistered.
        """
        websocket = self.active_streams.get(client_id)
        if websocket is None:
            return
        try:
            await websocket.send_text(frame_data)
        except Exception as e:
            logger.error(f"Failed to send frame to client {client_id}: {e}")
            await self.remove_stream(client_id)
|
||
|
|
|
||
|
|
|
||
|
|
# Module-level stream registry, shared by the streaming endpoint (which
# adds/removes sessions) and the status endpoint (which counts them).
stream_manager: StreamManager = StreamManager()
|
||
|
|
|
||
|
|
|
||
|
|
def _parse_start_message(raw: str) -> StreamConfig:
    """Validate the client's first WebSocket message and build its StreamConfig.

    Args:
        raw: Text of the first message received from the client.

    Returns:
        A StreamConfig built from the message's ``config`` object; model
        defaults apply when ``config`` is absent or null.

    Raises:
        ValueError: If ``raw`` is not JSON, is not an object whose ``action``
            is ``"start"``, or has a ``config`` that is not an object or fails
            model validation (pydantic's ``ValidationError`` subclasses
            ``ValueError``).
    """
    try:
        message = json.loads(raw)
    except ValueError as e:
        raise ValueError(f"Invalid JSON in start message: {e}") from e

    if not isinstance(message, dict) or message.get("action") != "start":
        raise ValueError("First message must be start action")

    options = message.get("config") or {}
    if not isinstance(options, dict):
        raise ValueError("Stream config must be a JSON object")
    return StreamConfig(**options)


def _encode_frame(capture, stream_config: StreamConfig) -> str:
    """Capture one frame of the configured region and return it base64-encoded.

    Unless ``resolution`` is ``"native"``, the frame is resized to the preset's
    width when the preset size differs from the region size. Only the width is
    passed to ``resize_image``; presumably it preserves the aspect ratio (confirm).
    """
    image = capture.capture_region(
        stream_config.region_x,
        stream_config.region_y,
        stream_config.region_width,
        stream_config.region_height,
    )

    if stream_config.resolution != "native":
        target_width, target_height = capture.get_resolution_preset(stream_config.resolution)
        # Only resize if different from capture size
        if (target_width, target_height) != (stream_config.region_width, stream_config.region_height):
            image = capture.resize_image(image, width=target_width)

    return capture.image_to_base64(image, quality=stream_config.quality)


@router.websocket("/ws/stream")
async def stream_video(websocket: WebSocket):
    """
    WebSocket endpoint for streaming video of a screen region.

    The client should send a JSON message with stream configuration:
    {
        "action": "start",
        "config": {
            "fps": 30,
            "quality": 85,
            "resolution": "720p",
            "region_x": 0,
            "region_y": 0,
            "region_width": 640,
            "region_height": 480
        }
    }

    An invalid first message is answered with ``{"error": ...}`` and the socket
    is closed. Otherwise the server replies ``{"type": "config", "data": ...}``
    with the effective settings. It then pushes
    ``{"type": "frame", "data": <base64>, "timestamp": <event-loop time>}``
    messages until the client disconnects or a capture error occurs. A capture
    error is reported as ``{"type": "error", "message": ...}`` on a best-effort
    basis.
    """
    config = get_config()
    client_id = id(websocket)
    stream_key = str(client_id)

    try:
        # Accept first: the start message cannot be read before accepting.
        await websocket.accept()
        logger.info(f"WebSocket connected: client {client_id}")

        # Wait for initial configuration message
        data = await websocket.receive_text()
        try:
            stream_config = _parse_start_message(data)
        except ValueError as e:
            await websocket.send_json({"error": str(e)})
            await websocket.close()
            return

        # Server config caps the frame rate and dictates the JPEG quality.
        # The floor of 1 keeps the frame-interval division below well-defined.
        stream_config.fps = max(1, min(stream_config.fps, config.capture.fps))
        stream_config.quality = config.capture.quality

        await stream_manager.add_stream(stream_key, websocket, stream_config)

        # Send confirmation
        await websocket.send_json({"type": "config", "data": stream_config.model_dump()})

        capture = get_screen_capture()
        loop = asyncio.get_running_loop()
        frame_interval = 1.0 / stream_config.fps

        while True:
            frame_started = loop.time()
            try:
                base64_image = _encode_frame(capture, stream_config)
                frame_data = json.dumps({"type": "frame", "data": base64_image, "timestamp": loop.time()})
                await websocket.send_text(frame_data)
            except WebSocketDisconnect:
                logger.info(f"Client {client_id} disconnected")
                break
            except Exception as e:
                logger.error(f"Error in stream loop for client {client_id}: {e}")
                try:
                    await websocket.send_json({"type": "error", "message": str(e)})
                except Exception:
                    # Best effort: the socket may already be gone.
                    logger.debug(f"Could not report stream error to client {client_id}")
                break

            # Sleep only for what remains of this frame's budget so capture and
            # encode time do not pull the effective rate below the target fps.
            await asyncio.sleep(max(0.0, frame_interval - (loop.time() - frame_started)))

    except WebSocketDisconnect:
        # Client left before (or while) sending the start message.
        logger.info(f"Client {client_id} disconnected")
    except Exception as e:
        logger.error(f"WebSocket error for client {client_id}: {e}")
    finally:
        await stream_manager.remove_stream(stream_key)
        logger.info(f"WebSocket closed: client {client_id}")
|
||
|
|
|
||
|
|
|
||
|
|
@router.post("/capture", response_model=Dict[str, Any])
async def capture_screen_region(request: CaptureRequest):
    """
    Capture a single frame from a screen region.

    This is useful for testing or getting a single snapshot.

    Only ``format="base64"`` is implemented. Any other value returns
    ``{"success": False, "error": ...}`` without capturing the screen.

    Returns:
        ``{"success": True, "format": "base64", "data": <base64>, "width", "height"}``
        with the requested region size.

    Raises:
        HTTPException: 500 if capturing or encoding the region fails.
    """
    # Reject unsupported formats before paying for a screen grab.
    if request.format != "base64":
        return {"success": False, "error": f"Unsupported format: {request.format}"}

    try:
        capture = get_screen_capture()
        image = capture.capture_region(request.x, request.y, request.width, request.height)
        base64_image = capture.image_to_base64(image, quality=request.quality)
    except Exception as e:
        # logger.exception keeps the traceback, which str(e) alone loses.
        logger.exception(f"Failed to capture region: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e

    return {
        "success": True,
        "format": "base64",
        "data": base64_image,
        "width": request.width,
        "height": request.height,
    }
|
||
|
|
|
||
|
|
|
||
|
|
@router.get("/regions")
async def get_regions():
    """
    Return the predefined capture regions for GSPro UI elements.

    The "map" region comes from the capture config; the others are fixed
    coordinates. The response is ``{"regions": {key: region}, "total": count}``.
    """
    capture_cfg = get_config().capture

    def region(name, x, y, width, height, description):
        # One region entry; key order matches the public response shape.
        return {
            "name": name,
            "x": x,
            "y": y,
            "width": width,
            "height": height,
            "description": description,
        }

    regions = {
        "map": region(
            "Map Panel",
            capture_cfg.region_x,
            capture_cfg.region_y,
            capture_cfg.region_width,
            capture_cfg.region_height,
            "GSPro mini-map or expanded map view",
        ),
        "club": region("Club Indicator", 50, 200, 200, 100, "Current club selection display"),
        "shot_info": region("Shot Information", 50, 50, 300, 150, "Shot distance and trajectory information"),
        "scorecard": region("Scorecard", 400, 100, 800, 600, "Scorecard overlay when visible"),
    }

    return {"regions": regions, "total": len(regions)}
|
||
|
|
|
||
|
|
|
||
|
|
@router.post("/regions/{region_name}")
async def update_region(region_name: str, region: RegionDefinition):
    """
    Update or create a screen region definition.

    Only ``region_name == "map"`` is supported: its geometry is written into
    the capture config and ``config.save()`` is called. Any other name returns
    ``success: False``. ``region.name`` and ``region.description`` are echoed
    back but not persisted.
    """
    config = get_config()

    if region_name != "map":
        # Only the map region is persisted for now; custom regions are a V2 feature.
        return {"success": False, "message": "Custom regions not yet supported (V2 feature)"}

    capture_cfg = config.capture
    capture_cfg.region_x, capture_cfg.region_y = region.x, region.y
    capture_cfg.region_width, capture_cfg.region_height = region.width, region.height
    config.save()

    return {"success": True, "message": f"Region '{region_name}' updated", "region": region.model_dump()}
|
||
|
|
|
||
|
|
|
||
|
|
# OCR endpoints - gated behind vision config flag
@router.post("/ocr")
async def perform_ocr(request: CaptureRequest):
    """
    Perform OCR on a screen region (V2 feature).

    Raises HTTPException 403 unless ``config.vision.enabled`` is set. When it
    is set, a placeholder response is returned; ``request`` is not used yet.
    """
    vision_enabled = get_config().vision.enabled

    if vision_enabled:
        # OCR implementation will go here in V2
        return {"success": False, "message": "OCR features coming in V2", "vision_enabled": vision_enabled}

    raise HTTPException(status_code=403, detail="Vision features are not enabled. This is a V2 feature.")
|
||
|
|
|
||
|
|
|
||
|
|
@router.get("/markers")
async def get_markers():
    """
    Get visual markers for template matching (V2 feature).

    Raises HTTPException 403 unless ``config.vision.enabled`` is set. When it
    is set, an empty placeholder marker list is returned.
    """
    vision_enabled = get_config().vision.enabled

    if vision_enabled:
        # Marker management will go here in V2
        return {"markers": [], "message": "Marker features coming in V2", "vision_enabled": vision_enabled}

    raise HTTPException(status_code=403, detail="Vision features are not enabled. This is a V2 feature.")
|
||
|
|
|
||
|
|
|
||
|
|
@router.get("/status")
async def get_vision_status():
    """Report which vision features are enabled, the active stream count, and capture settings."""
    config = get_config()
    capture_cfg = config.capture
    vision_on = config.vision.enabled

    capture_summary = {
        "fps": capture_cfg.fps,
        "quality": capture_cfg.quality,
        "resolution": capture_cfg.resolution,
    }

    return {
        # Streaming has no feature flag and is always reported as available.
        "streaming_enabled": True,
        # OCR and markers share the single vision flag.
        "ocr_enabled": vision_on,
        "markers_enabled": vision_on,
        "active_streams": len(stream_manager.active_streams),
        "capture_config": capture_summary,
    }
|