# Source code for rtiumaapy.command_provider

"""CommandProvider — server-side command interaction (Tier 0).

A ``CommandProvider`` receives commands addressed to it via a content-filtered
reader, manages ``CommandProviderSession`` instances for each active command,
and owns all ack/status/exec_status writers. Subclasses override async hooks
to implement domain-specific behavior.
"""

from __future__ import annotations

import asyncio
import logging
from typing import Dict, Optional, Tuple, Type

import rti.connextdds as dds
import rti.asyncio  # noqa: F401 — enables take_async

from rtiumaapy.base_service import BaseService
from rtiumaapy.dds_context import DDSContext
from rtiumaapy.guid_util import GUIDUtil
from rtiumaapy.timestamp import UmaaTimestamp
from rtiumaapy.validation import validate_message
from rtiumaapy.command_provider_session import (
    CommandProviderSession,
    CommandStatusEnum,
    CommandReasonEnum,
    _TERMINAL_STATES,
)

# Module-level logger, named after this module's dotted path (``__name__``).
_logger = logging.getLogger(__name__)


class CommandProvider(BaseService):
    """Subclass and override hooks to implement command behavior.

    Base class owns ALL status publishing — subclasses never publish status
    directly.

    Args:
        ctx: The :class:`DDSContext` owning shared DDS infrastructure.
        service_name: Unique name for this service instance.
        command_type: IDL-generated command struct type.
        ack_type: IDL-generated ack report struct type.
        status_type: IDL-generated command status struct type.
        exec_status_type: Optional execution status struct type.
        command_topic: DDS topic name for commands.
        ack_topic: DDS topic name for ack.
        status_topic: DDS topic name for status.
        exec_status_topic: Optional DDS topic name for exec status. The exec
            status writer is created only when both ``exec_status_type`` and
            ``exec_status_topic`` are given.
        source_id: This provider's ``IdentifierType`` identity. The command
            reader is content-filtered on ``destination.parentID`` /
            ``destination.id`` matching it.
    """

    def __init__(
        self,
        ctx: DDSContext,
        service_name: Optional[str] = None,
        *,
        command_type: Type,
        ack_type: Type,
        status_type: Type,
        command_topic: str,
        ack_topic: str,
        status_topic: str,
        exec_status_type: Optional[Type] = None,
        exec_status_topic: Optional[str] = None,
        source_id,
    ) -> None:
        super().__init__(ctx, service_name)
        self._command_type = command_type
        self._ack_type = ack_type
        self._status_type = status_type
        self._exec_status_type = exec_status_type
        self._source_id = source_id

        # Filtered reader for commands addressed to this provider.
        filter_expr = (
            f"destination.parentID = &hex({GUIDUtil.to_hex(source_id.parentID.value)})"
            f" AND destination.id = &hex({GUIDUtil.to_hex(source_id.id.value)})"
        )
        self._command_reader, self._command_cft = ctx.create_filtered_reader(
            command_type, command_topic, filter_expr,
        )

        self._ack_writer = ctx.create_writer(ack_type, ack_topic)
        self._status_writer = ctx.create_writer(status_type, status_topic)
        self._exec_status_writer = None
        if exec_status_type is not None and exec_status_topic is not None:
            self._exec_status_writer = ctx.create_writer(
                exec_status_type, exec_status_topic)

        # str(sessionID) -> live session.
        # NOTE(review): entries are added in this class but never removed
        # here; presumably CommandProviderSession removes itself when it
        # finishes — confirm, since the D3 check depends on it.
        self._active_sessions: Dict[str, CommandProviderSession] = {}
        # Task running _run(); None until start() is called.
        self._task: Optional[asyncio.Task] = None

    # ── Hooks (override in subclass) ──────────────────────────────────────

    async def validate_command(self, command) -> Tuple[bool, str]:
        """Validate a command after ISSUED status and ack are published.

        Default implementation checks all fields against UMAA IDL constraints
        (range, enum, nested struct). Override for additional domain-specific
        logic.

        Return ``(True, "")`` to accept, ``(False, reason_string)`` to reject.
        """
        valid, errors = validate_message(command)
        if not valid:
            return False, "; ".join(errors)
        return True, ""

    async def on_updated(
        self,
        session: CommandProviderSession,
        previous_command,
        updated_command,
    ) -> None:
        """Called when a command update arrives for an active session."""

    async def on_commanded(self, session: CommandProviderSession) -> None:
        """Called after COMMANDED status published, before EXECUTING."""

    async def on_executing(self, session: CommandProviderSession) -> None:
        """Called after EXECUTING status published. Do the actual work here.

        Raise ``CommandHookError`` for domain-specific failures.
        """
        raise NotImplementedError("Subclass must implement on_executing()")

    async def on_complete(self, session: CommandProviderSession) -> None:
        """Called after COMPLETED status published."""

    async def on_failed(
        self, session: CommandProviderSession, exception: Exception
    ) -> None:
        """Called after FAILED status published."""

    async def on_terminal(self, session: CommandProviderSession) -> None:
        """Always called (finally block), regardless of outcome."""

    # ── Event loop ────────────────────────────────────────────────────────

    def start(self) -> None:
        """Start the command processing loop.

        Does nothing if the loop task is still running; otherwise (never
        started, finished, or cancelled) schedules a new ``_run()`` task.
        """
        if self._task is None or self._task.done():
            self._task = asyncio.ensure_future(self._run())

    async def _run(self) -> None:
        """Main event loop — receives commands and dispatches to sessions.

        Valid samples go to :meth:`_handle_valid_sample`; samples without
        valid data go to :meth:`_handle_metadata_sample`. An exception raised
        while handling a single sample is logged and the loop keeps running,
        so one bad command cannot silently stop the provider. Cancellation
        ends the loop quietly; a failure of the reader iteration itself is
        logged and re-raised.
        """
        try:
            async for sample in self._command_reader.take_async():
                try:
                    if sample.info.valid:
                        await self._handle_valid_sample(sample.data)
                    else:
                        await self._handle_metadata_sample(sample.info)
                except asyncio.CancelledError:
                    # Never treat cancellation as a per-sample failure
                    # (CancelledError is an Exception subclass before 3.8).
                    raise
                except Exception:
                    _logger.exception(
                        "Provider %s: failed to handle command sample",
                        self.service_name)
        except asyncio.CancelledError:
            pass
        except Exception:
            _logger.exception("Provider %s command reader failed",
                              self.service_name)
            raise

    @staticmethod
    def _timestamp_key(command) -> Tuple[int, int]:
        """Return ``(seconds, nanoseconds)`` of ``command.timeStamp``.

        Tuples compare lexicographically, so ``a > b`` means "strictly newer".
        """
        stamp = command.timeStamp
        return (stamp.seconds, stamp.nanoseconds)

    async def _handle_valid_sample(self, command) -> None:
        """Route a valid command sample to session or rejection.

        Every valid command is acked first. Then:

        * same sessionID as an active session: forwarded with
          ``set_new_command`` only if its timestamp is strictly newer than
          the session's current command; otherwise ignored;
        * different sessionID while any session is active: rejected (D3);
        * otherwise: a new session is registered, ISSUED is published and the
          session's ``run()`` is scheduled as a task.
        """
        session_id = str(command.sessionID)

        # Ack every valid command immediately
        self._publish_ack(command)

        existing = self._active_sessions.get(session_id)
        if existing is not None:
            if (self._timestamp_key(command)
                    > self._timestamp_key(existing.command)):
                existing.set_new_command(command)
            # else: same or older timestamp → ignore
            return

        if self._active_sessions:
            # D3: Different sessionID while active → reject
            await self._reject_active(command)
            return

        # New session — publish ISSUED status synchronously with the ack
        # so that external observers (e.g. USTM) see the command and its first
        # status in the correct order, matching C++ provider behavior.
        session = CommandProviderSession(provider=self, command=command)
        self._active_sessions[session_id] = session
        try:
            session._transition_to(CommandStatusEnum.ISSUED,
                                   reason=CommandReasonEnum.SUCCEEDED)
            session._publish_status()
            session._task = asyncio.create_task(session.run())
        except Exception:
            # Without a running task nothing would ever remove this entry,
            # and the D3 check above would reject every later command.
            self._active_sessions.pop(session_id, None)
            raise

    async def _handle_metadata_sample(self, info: dds.SampleInfo) -> None:
        """Handle dispose / no-writers metadata.

        When the command instance is NOT_ALIVE_DISPOSED or
        NOT_ALIVE_NO_WRITERS, the matching active session (if any) is
        cancelled. Other instance states are ignored. Failures are logged,
        never raised.
        """
        instance_state = info.state.instance_state
        if instance_state not in (dds.InstanceState.NOT_ALIVE_DISPOSED,
                                  dds.InstanceState.NOT_ALIVE_NO_WRITERS):
            return

        try:
            key_data = self._command_reader.key_value(info.instance_handle)
            session_id = str(key_data.sessionID)
        except Exception:
            _logger.debug(
                "Provider %s: key_value recovery failed for disposed",
                self.service_name, exc_info=True)
            return

        session = self._active_sessions.get(session_id)
        if session is None:
            return
        try:
            await session.cancel()
        except Exception:
            _logger.warning(
                "Provider %s: cancelling session %s failed",
                self.service_name, session_id, exc_info=True)

    def _publish_ack(self, command) -> None:
        """Publish ack echoing the command back.

        Best effort: any failure is logged as a warning and not raised.
        """
        try:
            ack = self._ack_type()
            ack.command = command
            ack.source = self._source_id
            ack.sessionID = command.sessionID
            UmaaTimestamp.set_timestamp(ack)
            self._ack_writer.write(ack)
        except Exception:
            _logger.warning("Provider %s: ack publish failed",
                            self.service_name, exc_info=True)

    async def _reject_active(self, command) -> None:
        """Reject a command via ephemeral session (D3).

        Publishes ISSUED → COMMANDED → FAILED(RESOURCE_REJECTED) through a
        temporary session that is never registered in ``_active_sessions``
        and never run, then always calls its ``_dispose_provider_instances()``.
        Publishing failures are logged as a warning and not raised.
        """
        temp = CommandProviderSession(provider=self, command=command)
        try:
            for status, reason in (
                (CommandStatusEnum.ISSUED, CommandReasonEnum.SUCCEEDED),
                (CommandStatusEnum.COMMANDED, CommandReasonEnum.SUCCEEDED),
                (CommandStatusEnum.FAILED, CommandReasonEnum.RESOURCE_REJECTED),
            ):
                temp._transition_to(status, reason=reason)
                temp._publish_status()
        except Exception:
            _logger.warning("Provider %s: rejection sequence failed",
                            self.service_name, exc_info=True)
        finally:
            temp._dispose_provider_instances()

    # ── Lifecycle ─────────────────────────────────────────────────────────

    async def close(self) -> None:
        """Fail active sessions, cancel _run. Entity cleanup is deferred to DDSContext.

        The ``_run`` loop is cancelled first. Then every active session is
        failed, and all session tasks are awaited. A failure while failing
        one session is logged and does not stop the others from being
        failed and awaited.
        """
        # Cancel the _run loop first: sessions are only created from it,
        # so no new session can appear while the existing ones are failed.
        if self._task is not None and not self._task.done():
            self._task.cancel()
            try:
                await self._task
            except asyncio.CancelledError:
                pass
            except Exception:
                _logger.exception(
                    "Provider %s: _run task failed during close",
                    self.service_name,
                )

        # Fail all active sessions (iterate a snapshot: sessions may
        # mutate the dict while finishing).
        tasks = []
        for session_id, session in list(self._active_sessions.items()):
            try:
                await session.fail()
            except Exception:
                _logger.exception(
                    "Provider %s: failing session %s during close failed",
                    self.service_name, session_id,
                )
            if session._task is not None:
                tasks.append(session._task)
        if tasks:
            await asyncio.gather(*tasks, return_exceptions=True)