Source code for kani.ext.realtime.session

import asyncio
import logging
from typing import Any, Awaitable, Callable, TypeVar

import websockets
from websockets.asyncio.client import ClientConnection, connect

from ._internal import get_server_event_handlers, server_event_handler
from .events import ClientEvent, ServerEvent, server as server_events
from .models import ConversationItem, RealtimeResponse, SessionConfig

# logger for this module, named after the module itself
log = logging.getLogger(name=__name__)
# type variable used by helpers (such as ``RealtimeSession.wait_for``) that hand back a specific ServerEvent subclass
ServerEventT = TypeVar("ServerEventT", bound=ServerEvent)


class RealtimeSession:
    """This is an internal object used to manage the state of the OpenAI Realtime session."""

    def __init__(
        self,
        api_key: str,
        model="gpt-4o-realtime-preview-2024-10-01",
        *,
        ws_base: str = "wss://api.openai.com/v1/realtime",
        headers: dict | None = None,
        # organization: str = None,  # todo is this supported?
        # :param organization: The OpenAI organization to use in requests. By default, the org ID would be read from
        #     the `OPENAI_ORG_ID` environment variable (defaults to the API key's default org if not set).
        **generation_args,
    ):
        """
        :param api_key: Your OpenAI API key.
        :param model: The id of the realtime model to use (default "gpt-4o-realtime-preview-2024-10-01").
        :param ws_base: The base WebSocket URL to connect to.
        :param headers: A dict of HTTP headers to include with the WebSocket handshake. The dict is copied, not
            mutated; ``Authorization`` and ``OpenAI-Beta`` are filled in unless already present.
        :param generation_args: The arguments to pass to the ``response.create`` call with each request. See
            https://platform.openai.com/docs/api-reference/realtime-client-events/response/create for a full list of
            params. Specifically, these arguments will be passed as the ``response`` key.
        """
        # default headers - copy so we neither crash on the ``None`` default nor mutate the caller's dict
        headers = {} if headers is None else dict(headers)
        headers.setdefault("Authorization", f"Bearer {api_key}")
        headers.setdefault("OpenAI-Beta", "realtime=v1")

        # client config
        self.ws_base = ws_base
        self.headers = headers
        self.model = model
        self.generation_args = generation_args

        # state
        self.session_config: SessionConfig | None = None
        self.session_id: str | None = None
        self.conversation_id: str | None = None
        self.responses: dict[str, RealtimeResponse] = {}
        self.conversation_items: dict[str, ConversationItem] = {}

        # ws
        self.ws: ClientConnection | None = None
        self._ws_connected = asyncio.Event()
        self._session_created = asyncio.Event()
        self.listeners: list[Callable[[ServerEvent], Awaitable[Any]]] = []
        self.ws_task: asyncio.Task | None = None

        # event handlers
        self._server_event_handlers = get_server_event_handlers(self)

    # ==== lifecycle ====
    async def connect(self):
        """
        Connect to the WS, begin a task for event handling, and init the session.

        You should usually call :meth:`.OpenAIRealtimeKani.connect` instead of this.

        If the receive task exits before the socket is open and ``session.created`` has been received (e.g. the
        handshake is rejected), the task's exception is re-raised here instead of waiting forever.
        """
        # (re)start the receive loop if it was never started or has since exited (closed / dropped connection)
        if self.ws_task is None or self.ws_task.done():
            self.ws_task = asyncio.create_task(self._ws_task(), name="realtime-ws")
        await self._wait_while_ws_alive(self._ws_connected)
        await self._wait_while_ws_alive(self._session_created)

    async def _wait_while_ws_alive(self, event: asyncio.Event):
        """
        Wait until *event* is set, failing fast if the receive task finishes first.

        :raises: The receive task's own exception if it failed, otherwise :class:`RuntimeError`.
        """
        if event.is_set():
            return
        task = self.ws_task
        waiter = asyncio.create_task(event.wait())
        try:
            done, _ = await asyncio.wait({waiter, task}, return_when=asyncio.FIRST_COMPLETED)
        finally:
            # no-op if the event already fired; otherwise don't leave a dangling waiter behind
            waiter.cancel()
        if waiter in done:
            return
        # the receive loop exited before the event was set - surface why instead of hanging forever
        if not task.cancelled() and (exc := task.exception()) is not None:
            raise exc
        raise RuntimeError("Websocket task exited before the realtime session was ready")

    async def close(self):
        """Cancel the receive task (which closes the WS) and wait for it to finish shutting down."""
        task, self.ws_task = self.ws_task, None
        # a later connect() must wait for a fresh session.created
        self._session_created.clear()
        if task is None:
            return
        task.cancel()
        # let the ``async with connect(...)`` unwind so the socket is closed by the time we return;
        # return_exceptions keeps a failed/cancelled task from raising out of close()
        await asyncio.gather(task, return_exceptions=True)

    # ==== iface ====
    async def send(self, event: ClientEvent):
        """Send a client event to the websocket."""
        if self.ws is None:
            raise RuntimeError("Websocket is not yet initialized - call connect() first")
        log.debug(f">>> {event!r}")
        data = event.model_dump_json()
        await self.ws.send(data)

    # ==== events ====
    def add_listener(self, callback: Callable[[ServerEvent], Awaitable[Any]]):
        """
        Add a listener which is called for every event received from the WS.

        The listener must be an asynchronous function that takes in an event in a single argument.
        """
        self.listeners.append(callback)

    def remove_listener(self, callback):
        """Remove a listener added by :meth:`add_listener`."""
        self.listeners.remove(callback)

    async def wait_for(
        self,
        event_type: str,
        predicate: Callable[[ServerEventT], bool] | None = None,
        timeout: float | None = 60,
    ) -> ServerEventT:
        """
        Wait for the next event of a given type, and return it.

        :param event_type: The ``type`` of the server event to wait for.
        :param predicate: If given, only events of that type for which this returns True are accepted.
        :param timeout: Seconds to wait before raising :class:`asyncio.TimeoutError` (``None`` waits forever).
        """
        future = asyncio.get_running_loop().create_future()

        async def waiter(e: ServerEvent):
            # several matching events may be dispatched before this listener is removed; only the first one wins
            if future.done() or e.type != event_type:
                return
            if predicate is None or predicate(e):
                future.set_result(e)

        try:
            self.add_listener(waiter)
            return await asyncio.wait_for(future, timeout)
        finally:
            self.remove_listener(waiter)

    async def _ws_task(self):
        """Main websocket receive loop."""
        try:
            async with connect(f"{self.ws_base}?model={self.model}", additional_headers=self.headers) as self.ws:
                self._ws_connected.set()
                async for data in self.ws:
                    # noinspection PyBroadException
                    try:
                        event = ServerEvent.model_validate_json(data)
                        log.debug(f"<<< {event!r}")
                        # process our event first, always
                        await self._handle_server_event(event)
                        # get listeners, call them - listeners can use the result of the processing if needed
                        listeners = list(self.listeners)
                        results = await asyncio.gather(
                            *(callback(event) for callback in listeners), return_exceptions=True
                        )
                        # one misbehaving listener must not break the others, but its failure should be visible
                        for callback, result in zip(listeners, results):
                            if isinstance(result, Exception):
                                log.error(f"Exception in WS event listener {callback!r}:", exc_info=result)
                    except Exception:
                        log.exception("Exception when handling WS event:")
        except websockets.ConnectionClosedError as e:
            # raised by the ``async for`` iterator itself on an abnormal close, so it must be caught out here
            log.error(f"WS connection closed unexpectedly: {e}")
        except asyncio.CancelledError:
            return
        finally:
            self._ws_connected.clear()
            # the session dies with the socket; a reconnect must wait for a new session.created
            self._session_created.clear()
            self.ws = None

    # ==== ws event handlers ====
    async def _handle_server_event(self, event: ServerEvent):
        """
        Main entrypoint for received server events.

        Will always be fully processed before WS events are dispatched to consumers to allow consumers to read from
        state instead of updating their state.
        """
        handler = self._server_event_handlers.get(event.type)
        if handler is None:
            # warnings.warn(f"A server event with type {event.type!r} is being unhandled: {event!r}")
            return
        await handler(event)

    @server_event_handler("error")
    async def _handle_error(self, event: server_events.Error):
        log.error(event.error)

    @server_event_handler("session.created")
    async def _handle_session_created(self, event: server_events.SessionCreated):
        # record the session before signalling connect() so it observes populated state
        self.session_id = event.session.id
        self.session_config = event.session
        self._session_created.set()

    @server_event_handler("session.updated")
    async def _handle_session_updated(self, event: server_events.SessionUpdated):
        self.session_id = event.session.id
        self.session_config = event.session

    @server_event_handler("conversation.created")
    async def _handle_conversation_created(self, event: server_events.ConversationCreated):
        self.conversation_id = event.conversation.id

    @server_event_handler("conversation.item.created")
    async def _handle_conversation_item_created(self, event: server_events.ConversationItemCreated):
        item_id = event.item.id
        self.conversation_items[item_id] = event.item

    @server_event_handler("conversation.item.input_audio_transcription.completed")
    async def _handle_conversation_item_input_audio_transcription_completed(
        self, event: server_events.ConversationItemInputAudioTranscriptionCompleted
    ):
        content = self.get_item_content(event.item_id, event.content_index)
        # get_item_content has already logged a warning if the item/content is unknown
        if content is not None:
            content.transcript = event.transcript.strip()

    @server_event_handler("conversation.item.input_audio_transcription.failed")
    async def _handle_conversation_item_input_audio_transcription_failed(
        self, event: server_events.ConversationItemInputAudioTranscriptionFailed
    ):
        content = self.get_item_content(event.item_id, event.content_index)
        if content is not None:
            content.transcript = "[transcript failed]"  # todo
        log.warning(f"Audio transcription failed: {event.error}")

    @server_event_handler("conversation.item.truncated")
    async def _handle_conversation_item_truncated(self, event: server_events.ConversationItemTruncated):
        pass

    @server_event_handler("conversation.item.deleted")
    async def _handle_conversation_item_deleted(self, event: server_events.ConversationItemDeleted):
        self.conversation_items.pop(event.item_id, None)

    @server_event_handler("input_audio_buffer.committed")
    async def _handle_input_audio_buffer_committed(self, event: server_events.InputAudioBufferCommitted):
        pass  # todo create an in progress item id?

    @server_event_handler("input_audio_buffer.cleared")
    async def _handle_input_audio_buffer_cleared(self, event: server_events.InputAudioBufferCleared):
        pass

    @server_event_handler("input_audio_buffer.speech_started")
    async def _handle_input_audio_buffer_speech_started(self, event: server_events.InputAudioBufferSpeechStarted):
        pass  # todo create an in progress item id?

    @server_event_handler("input_audio_buffer.speech_stopped")
    async def _handle_input_audio_buffer_speech_stopped(self, event: server_events.InputAudioBufferSpeechStopped):
        pass  # todo end an in progress item id?

    @server_event_handler("response.created")
    async def _handle_response_created(self, event: server_events.ResponseCreated):
        self.responses[event.response.id] = event.response
        for item in event.response.output:
            self.conversation_items[item.id] = item

    @server_event_handler("response.done")
    async def _handle_response_done(self, event: server_events.ResponseDone):
        self.responses[event.response.id] = event.response
        for item in event.response.output:
            self.conversation_items[item.id] = item

    @server_event_handler("rate_limits.updated")
    async def _handle_rate_limits_updated(self, event: server_events.RateLimitsUpdated):
        pass  # todo

    # ==== helpers ====
    def get_item_content(self, item_id: str, content_index: int):
        """
        Return content part ``content_index`` of the known conversation item ``item_id``.

        Returns None (after logging a warning) if the item is unknown or the index is out of range.
        """
        item = self.conversation_items.get(item_id)
        if item is None:
            log.warning(f"Got event referencing item that does not exist: {item_id}")
            return
        if content_index < 0 or content_index >= len(item.content):
            log.warning(
                f"Got event referencing item content out of bounds (len {len(item.content)}, got {content_index})"
            )
            return
        return item.content[content_index]