# Source code for rtiumaapy.command_consumer

"""CommandConsumer — client-side command interaction (Tier 0).

A ``CommandConsumer`` sends commands, receives ack/status/exec_status via
content-filtered readers, and delivers updates via subclass-override hooks.
One active session at a time.
"""

from __future__ import annotations

import asyncio
import logging
from typing import Optional, Type

import rti.connextdds as dds
import rti.asyncio  # noqa: F401 — enables take_async

from rtiumaapy.base_service import BaseService
from rtiumaapy.dds_context import DDSContext
from rtiumaapy.guid_util import GUIDUtil
from rtiumaapy.timestamp import UmaaTimestamp
from rtiumaapy.validation import validate_message
from rtiumaapy.command_provider_session import CommandStatusEnum

_logger = logging.getLogger(__name__)

_TERMINAL_STATUSES = {
    CommandStatusEnum.COMPLETED,
    CommandStatusEnum.FAILED,
    CommandStatusEnum.CANCELED,
}


class CommandConsumer(BaseService):
    """Subclass and override hooks to react to command lifecycle events.

    The consumer writes commands on ``command_topic`` and listens for ack,
    status and (optionally) execution-status samples through content-filtered
    readers.  Only one session can be active at a time; while no session is
    active the readers use a filter that matches nothing.

    Args:
        ctx: The :class:`DDSContext` owning shared DDS infrastructure.
        service_name: Unique name for this service instance.
        command_type: IDL-generated command struct type.
        ack_type: IDL-generated ack report struct type.
        status_type: IDL-generated command status struct type.
        exec_status_type: Optional execution status struct type.
        command_topic: DDS topic name for commands.
        ack_topic: DDS topic name for ack.
        status_topic: DDS topic name for status.
        exec_status_topic: Optional DDS topic name for exec status.
        source_id: This consumer's ``IdentifierType`` identity.
        destination_id: Target provider's ``IdentifierType`` identity.
    """

    # Filter expression that matches no sample; used while no session is active.
    _BLOCKING_FILTER = "1 = 0"

    def __init__(
        self,
        ctx: DDSContext,
        service_name: Optional[str] = None,
        *,
        command_type: Type,
        ack_type: Type,
        status_type: Type,
        command_topic: str,
        ack_topic: str,
        status_topic: str,
        exec_status_type: Optional[Type] = None,
        exec_status_topic: Optional[str] = None,
        source_id=None,
        destination_id=None,
    ) -> None:
        super().__init__(ctx, service_name)
        self._command_type = command_type
        self._source_id = source_id
        self._destination_id = destination_id

        self._command_writer = ctx.create_writer(command_type, command_topic)

        # Readers start with the blocking filter ("1 = 0").
        self._ack_reader, self._ack_cft = ctx.create_filtered_reader(
            ack_type, ack_topic, self._BLOCKING_FILTER)
        self._status_reader, self._status_cft = ctx.create_filtered_reader(
            status_type, status_topic, self._BLOCKING_FILTER)

        # The exec-status reader exists only when both type and topic are given.
        self._exec_status_reader = None
        self._exec_status_cft = None
        if exec_status_type is not None and exec_status_topic is not None:
            self._exec_status_reader, self._exec_status_cft = (
                ctx.create_filtered_reader(
                    exec_status_type, exec_status_topic,
                    self._BLOCKING_FILTER))

        # Active session state
        self._session_id: Optional[bytes] = None
        self._session_command = None
        self._task: Optional[asyncio.Task] = None

    # ── Hooks (override in subclass) ──────────────────────────────────────

    async def on_status(self, session_id: bytes, status) -> None:
        """Called for every status update from the provider.

        Do not call ``cancel()`` from this hook — terminal statuses are
        detected automatically and trigger ``_end_session()`` after this
        hook returns.
        """

    async def on_ack(self, session_id: bytes, ack) -> None:
        """Called when the provider echoes the command acknowledgment."""

    async def on_exec_status(self, session_id: bytes, exec_status) -> None:
        """Called when the provider publishes execution progress."""

    async def on_terminal(self, session_id: bytes, status) -> None:
        """Called after the session closes.

        Args:
            session_id: The session that ended.
            status: Terminal status sample, or None for cancel/crash/shutdown.
        """

    # ── Discovery ──────────────────────────────────────────────────────────

    @property
    def has_matched_provider(self) -> bool:
        """True if at least one provider is subscribed to the command topic."""
        return self._command_writer.publication_matched_status.current_count > 0

    async def wait_for_discovery(self, timeout: float = 30.0) -> bool:
        """Block until a provider subscribes to the command topic.

        Polls :attr:`has_matched_provider` at most every 0.25 seconds.

        Args:
            timeout: Maximum number of seconds to wait.

        Returns:
            True if a provider was discovered, False on timeout.
        """
        loop = asyncio.get_running_loop()
        deadline = loop.time() + timeout
        while not self.has_matched_provider:
            remaining = deadline - loop.time()
            if remaining <= 0:
                return False
            await asyncio.sleep(min(0.25, remaining))
        return True

    # ── Send / Cancel ─────────────────────────────────────────────────────

    async def send(self, command, session_id: Optional[bytes] = None) -> bytes:
        """Send a command or update. Returns the session_id handle.

        If ``session_id`` is None, starts a new session. If provided,
        sends an update to the existing session (D39). All header fields
        are auto-stamped (D47).

        Args:
            command: Command sample to write; its ``source``,
                ``destination``, ``sessionID`` and timestamp are overwritten.
            session_id: Handle of the active session to update, or None to
                start a new session.

        Returns:
            The session_id handle of the session the command belongs to.

        Raises:
            RuntimeError: If starting a new session while one is active.
            RuntimeError: If session_id doesn't match the active session.
            RuntimeError: If a new session is ended (e.g. by ``cancel()``)
                before its command could be written.
        """
        if session_id is not None:
            # Update existing session
            if self._session_id is None or session_id != self._session_id:
                raise RuntimeError("No active session with that session_id")
            command.source = self._source_id
            command.destination = self._destination_id
            command.sessionID = self._session_command.sessionID
            UmaaTimestamp.set_timestamp(command)
            self._session_command = command
            self._command_writer.write(command)
            return session_id

        # New session
        if self._session_id is not None:
            raise RuntimeError("Consumer already has an active session")

        from rtiumaapy.datamodel.Measurements import \
            UMAA_Common_Measurement_NumericGUID as NumericGUID

        session_id_bytes = GUIDUtil.generate()
        command.source = self._source_id
        command.destination = self._destination_id
        command.sessionID = NumericGUID(value=dds.Uint8Seq(session_id_bytes))
        UmaaTimestamp.set_timestamp(command)

        self._session_id = session_id_bytes
        self._session_command = command
        try:
            self._set_session_filter(session_id_bytes)
            await asyncio.sleep(0.01)  # let CFT filter propagate (C78)
            if self._session_id != session_id_bytes:
                # The session was torn down while we were waiting; writing
                # now would publish a command nobody tracks or disposes.
                raise RuntimeError(
                    "Session ended before the command was sent")
            self._command_writer.write(command)
        except BaseException:
            # Nothing was published: undo the half-started session so the
            # consumer is not left stuck with an "active" session.
            if self._session_id == session_id_bytes:
                self._session_id = None
                self._session_command = None
                self._reset_filters()
            raise
        return session_id_bytes

    async def cancel(self) -> None:
        """Cancel the active session — dispose command and clean up (D50)."""
        if self._session_id is None:
            return
        await self._end_session(None)

    # ── Event loop ────────────────────────────────────────────────────────

    def start(self) -> None:
        """Start the reader dispatch loops.

        Does nothing if a previously started dispatch task is still running.
        """
        if self._task is None or self._task.done():
            self._task = asyncio.ensure_future(self._run())

    async def _run(self) -> None:
        """Multiplex all consumer readers via gather.

        Cancellation is absorbed; any other reader-loop failure is logged
        and re-raised.  In every case the remaining reader loops are
        cancelled before this coroutine finishes.
        """
        loops = [
            self._read_status_loop(),
            self._read_ack_loop(),
        ]
        if self._exec_status_reader is not None:
            loops.append(self._read_exec_status_loop())
        tasks = [asyncio.ensure_future(loop) for loop in loops]
        try:
            await asyncio.gather(*tasks)
        except asyncio.CancelledError:
            pass
        except Exception:
            _logger.exception("Consumer %s reader loop failed",
                              self.service_name)
            raise
        finally:
            # gather() does not cancel the other awaitables when one of
            # them raises, so stop the survivors explicitly.
            pending = [task for task in tasks if not task.done()]
            for task in pending:
                task.cancel()
            if pending:
                await asyncio.gather(*pending, return_exceptions=True)

    def _warn_if_invalid(self, kind: str, data) -> None:
        """Log a warning if ``data`` fails validation; it is still delivered."""
        valid, errors = validate_message(data)
        if not valid:
            _logger.warning(
                "Consumer %s received invalid %s: %s",
                self.service_name, kind, "; ".join(errors),
            )

    async def _invoke_hook(self, hook_name: str, *args) -> None:
        """Await the named hook, logging (not propagating) its exceptions."""
        try:
            await getattr(self, hook_name)(*args)
        except Exception:
            _logger.exception(
                "Consumer %s %s hook error",
                self.service_name, hook_name,
            )

    async def _read_status_loop(self) -> None:
        """Status reader dispatch — on_status + terminal detection."""
        async for sample in self._status_reader.take_async():
            if not sample.info.valid:
                # Invalid samples carry instance-state changes only.
                instance_state = sample.info.state.instance_state
                if instance_state == dds.InstanceState.NOT_ALIVE_NO_WRITERS:
                    if self._session_id is not None:
                        await self._end_session(None)
                continue

            session_id = self._session_id
            if session_id is None:
                continue

            self._warn_if_invalid("status", sample.data)
            await self._invoke_hook("on_status", session_id, sample.data)

            # The hook awaited, so the active session may have changed; only
            # end the session this sample was dispatched for.
            if (sample.data.commandStatus in _TERMINAL_STATUSES
                    and self._session_id == session_id):
                await self._end_session(sample.data)

    async def _read_ack_loop(self) -> None:
        """Ack reader dispatch — on_ack hook."""
        async for sample in self._ack_reader.take_async():
            if not sample.info.valid or self._session_id is None:
                continue
            self._warn_if_invalid("ack", sample.data)
            await self._invoke_hook("on_ack", self._session_id, sample.data)

    async def _read_exec_status_loop(self) -> None:
        """Exec status reader dispatch — on_exec_status hook."""
        async for sample in self._exec_status_reader.take_async():
            if not sample.info.valid or self._session_id is None:
                continue
            self._warn_if_invalid("exec_status", sample.data)
            await self._invoke_hook(
                "on_exec_status", self._session_id, sample.data)

    # ── Session lifecycle ─────────────────────────────────────────────────

    async def _end_session(self, status) -> None:
        """Sole session cleanup owner (C35).

        1. Dispose command instance
        2. Clear session state
        3. Reset CFT filters to ``"1=0"``
        4. Call ``on_terminal()``

        Args:
            status: Terminal status sample to hand to ``on_terminal()``, or
                None when the session ends without one.
        """
        if self._session_id is None:
            return  # already cleaned up (C66)

        self._dispose_command_instance()
        session_id = self._session_id
        self._session_id = None
        self._session_command = None
        self._reset_filters()

        await self._invoke_hook("on_terminal", session_id, status)

    def _dispose_command_instance(self) -> None:
        """Dispose consumer's command instance (ICD §5.1.5).

        Best effort: failures are logged at debug level and not raised.
        """
        if self._session_command is None:
            return
        try:
            handle = self._command_writer.lookup_instance(
                self._session_command)
            if handle != dds.InstanceHandle.nil():
                self._command_writer.dispose_instance(handle)
        except Exception:
            _logger.debug("Consumer %s: command dispose failed",
                          self.service_name, exc_info=True)

    def _set_session_filter(self, session_id: bytes) -> None:
        """Set CFTs to receive responses for this session."""
        expr = f"sessionID = &hex({GUIDUtil.to_hex(session_id)})"
        self._apply_filter(dds.Filter(expr))

    def _reset_filters(self) -> None:
        """Block all data — no active session."""
        self._apply_filter(dds.Filter(self._BLOCKING_FILTER))

    def _apply_filter(self, filt) -> None:
        """Install ``filt`` on the ack, status and (if present) exec CFTs."""
        self._ack_cft.set_filter(filt)
        self._status_cft.set_filter(filt)
        if self._exec_status_cft is not None:
            self._exec_status_cft.set_filter(filt)

    # ── Lifecycle ─────────────────────────────────────────────────────────

    async def close(self) -> None:
        """End active session, cancel _run.

        Entity cleanup is deferred to DDSContext.
        """
        if self._session_id is not None:
            await self._end_session(None)
        if self._task is not None and not self._task.done():
            self._task.cancel()
            try:
                await self._task
            except asyncio.CancelledError:
                pass