import asyncio
import base64
import collections
import contextlib
import logging
import os
import warnings
from typing import AsyncIterable, Awaitable, Callable
from kani import AIFunction, ChatMessage, ExceptionHandleResult, Kani, ToolCall
from kani.engines.base import BaseCompletion, BaseEngine, Completion
from kani.exceptions import FunctionCallException
from kani.models import ChatRole, FunctionCall, QueryType
from kani.streaming import DummyStream, StreamManager
from . import interop, models as oaimodels
from .events import client as client_events, server as server_events
from .session import RealtimeSession
log = logging.getLogger(__name__)
class DummyEngine(BaseEngine):
    """Placeholder engine handed to :class:`kani.Kani` so that its constructor has an engine to hold.

    The realtime kani in this module requests completions through its realtime session rather than through
    this engine, so :meth:`predict` is deliberately left unimplemented.
    """

    max_context_size = 128000

    def message_len(self, message: ChatMessage) -> int:
        """Return a rough length estimate for *message*: one unit per four characters of text."""
        # message.text is None for content-less messages (e.g. assistant messages that only carry tool calls),
        # so fall back to the empty string instead of letting len() raise a TypeError
        return len(message.text or "") // 4

    def function_token_reserve(self, functions: list[AIFunction]) -> int:
        """Reserve no space for function definitions."""
        return 0

    async def predict(
        self, messages: list[ChatMessage], functions: list[AIFunction] | None = None, **hyperparams
    ) -> BaseCompletion:
        """Always raise :class:`NotImplementedError`; this engine cannot generate completions."""
        raise NotImplementedError
[docs]
class OpenAIRealtimeKani(Kani):
r"""
In addition to all of :class:`kani.Kani`\ 's methods, the OpenAIRealtimeKani provides the following two methods
for interacting with the realtime API.
"""
def __init__(
    self,
    # realtime session args
    api_key: str | None = None,
    model="gpt-4o-realtime-preview-2024-10-01",
    *,
    ws_base: str = "wss://api.openai.com/v1/realtime",
    headers: dict | None = None,
    # kani args
    system_prompt: str | None = None,
    chat_history: list[ChatMessage] | None = None,
    always_included_messages: list[ChatMessage] | None = None,
    **generation_args,
):
    """
    :param api_key: Your OpenAI API key. By default, the API key will be read from the `OPENAI_API_KEY` environment
        variable.
    :param model: The id of the realtime model to use (default "gpt-4o-realtime-preview-2024-10-01").
    :param ws_base: The base WebSocket URL to connect to (default "wss://api.openai.com/v1/realtime").
    :param headers: A dict of HTTP headers to include with each request.
    :param system_prompt: The system prompt to provide to the LM. The prompt *will* be included in chat_history.

        .. note::
            For interacting with the Realtime API, you may instead wish to provide session instructions by
            providing the ``instructions`` key to :class:`.SessionConfig` in your :meth:`connect` call.

    :param chat_history: The chat history to start with (not including system prompt or always included messages),
        e.g. for few-shot prompting. By default, each kani starts with a new conversation session.
    :param always_included_messages: Prepended to ``chat_history``.

        .. warning::
            Unlike normal Kanis, due to the server-managed nature of the OpenAI realtime API, messages marked as
            always included may not always be included by the server. These messages will instead be prepended
            to any ``chat_history`` and *will* be included in the ``chat_history`` attribute.

    :param generation_args: The arguments to pass to the ``response.create`` call with each request. See
        https://platform.openai.com/docs/api-reference/realtime-client-events/response/create for a full list of
        params. Specifically, these arguments will be passed as the ``response`` key.
    """
    if headers is None:
        headers = {}
    if api_key is None:
        api_key = os.getenv("OPENAI_API_KEY")

    # The property setters for always_included_messages / chat_history consult _has_connected, and the base
    # constructor assigns through those setters, so this state must exist before the base __init__ runs.
    self._has_connected = False
    self._always_included_messages = always_included_messages
    self._chat_history = chat_history

    super().__init__(
        engine=DummyEngine(),
        system_prompt=system_prompt,
        always_included_messages=always_included_messages,
        chat_history=chat_history,
    )
    # replace the lock set up by the base class with a no-op context manager
    self.lock = contextlib.nullcontext()

    self.session = RealtimeSession(
        api_key=api_key, model=model, ws_base=ws_base, headers=headers, **generation_args
    )
    """The underlying state of the OpenAI Realtime API. Used for lower-level API operations."""
# ===== lifecycle =====
[docs]
async def connect(self, session_config: oaimodels.SessionConfig | None = None):
    """Connect to the WS, push the session configuration, and upload the initial chat history.

    :param session_config: The session configuration to send to the server. If omitted, a default
        :class:`.SessionConfig` is used. The tools of every AI function registered on this kani are appended to
        the config's ``tools``; the object passed in is not modified.
    :raises RuntimeError: If this kani has already connected once.
    """
    if self._has_connected:
        raise RuntimeError("This RealtimeKani has already connected to the socket.")
    if session_config is None:
        # we want input_audio_transcription to be on by default - see models for default config
        session_config = oaimodels.SessionConfig()

    self._has_connected = True
    await self.session.connect()

    # configure tools: merge this kani's AI functions into the requested tools on a *copy* so the caller's
    # config object is left untouched
    kani_tools = [interop.ai_function_to_tool(f) for f in self.functions.values()]
    session_config = session_config.model_copy(update={"tools": session_config.tools + kani_tools})

    # send session config over WS
    await self.session.send(client_events.SessionUpdate(session=session_config))
    await self.session.wait_for("session.updated")

    # send chat history over ws: always-included messages first, then the initial chat history
    if self.always_included_messages:
        warnings.warn(
            "Due to the server-managed nature of the OpenAI realtime API, messages marked as always included may"
            " not always be included by the server."
        )
    initial_messages = [*(self.always_included_messages or []), *(self._chat_history or [])]
    for msg in initial_messages:
        for item in interop.chat_message_to_conv_items(msg):
            await self.session.send(client_events.ConversationItemCreate(item=item))
            # wait for the server to acknowledge each item before sending the next one
            await self.session.wait_for("conversation.item.created")
@property
def is_connected(self) -> bool:
    """Whether :meth:`connect` has been called on this kani (the flag is set when connecting starts)."""
    return self._has_connected
# ===== weird overrides =====
@property
def always_included_messages(self):
    """The messages that were marked as always included when this kani was created."""
    return self._always_included_messages

@always_included_messages.setter
def always_included_messages(self, value):
    # assignment is only allowed while the kani has not connected yet
    if not self._has_connected:
        self._always_included_messages = value
        return
    raise ValueError("The chat history cannot be directly modified after connecting to the WS.")
@property
def chat_history(self):
    """The locally stored chat history before connecting; ``None`` once the kani has connected."""
    if self._has_connected:
        # todo read chat items from session
        # todo return immutable
        return None
    return self._chat_history

@chat_history.setter
def chat_history(self, value):
    # assignment is only allowed while the kani has not connected yet
    if not self._has_connected:
        self._chat_history = value
        return
    raise ValueError("The chat history cannot be directly modified after connecting to the WS.")
async def add_to_history(self, message: ChatMessage):
    """Deliberate no-op: nothing is stored locally for *message*."""
    # todo maybe call a conversation.item.add if not in session?
    return None
async def get_prompt(self) -> list[ChatMessage]:
    """Return an empty prompt; no prompt is assembled locally by this kani."""
    empty_prompt: list[ChatMessage] = []
    return empty_prompt
# ===== kani iface =====
async def get_model_completion(self, include_functions: bool = True, **kwargs) -> Completion:
    """Request a completion now and return it.

    :param include_functions: If False, the response is requested with ``tool_choice`` set to ``"none"``.
    :param kwargs: Overrides applied on top of the session's config for this one response.
    :returns: The completed message together with the token usage of the response.
    """
    if not include_functions:
        kwargs["tool_choice"] = "none"

    # same guard as get_model_stream: without a session config, let the server use its defaults
    base_config = self.session.session_config
    response_config = base_config.model_copy(update=kwargs) if base_config else None
    await self.session.send(client_events.ResponseCreate(response=response_config))

    created: server_events.ResponseCreated = await self.session.wait_for("response.created")
    # only accept the response.done event that belongs to the response we just requested
    done: server_events.ResponseDone = await self.session.wait_for(
        "response.done", lambda e: e.response.id == created.response.id
    )

    # token usage lives on the response object nested in the event (as read by get_model_stream and
    # full_duplex), not on the event itself
    response = done.response
    message = interop.response_to_chat_message(response)
    return Completion(
        message=message,
        prompt_tokens=response.usage.input_tokens,
        completion_tokens=response.usage.output_tokens,
    )
async def get_model_stream(
self, include_functions: bool = True, audio_callback: Callable[[bytes], Awaitable] = None, **kwargs
) -> AsyncIterable[str | BaseCompletion]:
"""
Request a completion and stream from the model until the next response.done event. Only yield events from this
completion.
"""
if not include_functions:
kwargs["tool_choice"] = "none"
if audio_callback is None:
async def audio_callback(_):
pass
response_config = self.session.session_config.model_copy(update=kwargs) if self.session.session_config else None
await self.session.send(client_events.ResponseCreate(response=response_config))
response_created_data: server_events.ResponseCreated = await self.session.wait_for("response.created")
break_sentinel = object()
completion = None
q = asyncio.Queue()
async def listener(e):
match e:
case server_events.ResponseTextDelta(response_id=response_created_data.response.id, delta=text):
await q.put(text)
case server_events.ResponseAudioTranscriptDelta(
response_id=response_created_data.response.id, delta=text
):
await q.put(text)
case server_events.ResponseAudioDelta(response_id=response_created_data.response.id, delta=audio_b64):
await audio_callback(base64.b64decode(audio_b64))
case server_events.ResponseDone(response=response):
message = interop.response_to_chat_message(response)
nonlocal completion
completion = Completion(
message=message,
prompt_tokens=response.usage.input_tokens,
completion_tokens=response.usage.output_tokens,
)
await q.put(break_sentinel)
self.session.add_listener(listener)
try:
while True:
item = await q.get()
if item is break_sentinel:
log.debug("Got break sentinel, yielding completion")
break
yield item
finally:
self.session.remove_listener(listener)
if completion:
yield completion
async def _full_round(self, query: QueryType, *, max_function_rounds: int, _kani_is_stream: bool, **kwargs):
    """Drive one full round (model turns plus tool calls), yielding messages or streams as they are produced."""
    retry = 0
    rounds_done = 0
    model_should_respond = True

    # push the user's query (if any) into the server-side conversation
    if query is not None:
        user_msg = ChatMessage.user(query)
        await self.add_to_history(user_msg)
        for conv_item in interop.chat_message_to_conv_items(user_msg):
            await self.session.send(client_events.ConversationItemCreate(item=conv_item))

    while model_should_respond:
        # ask the model for its turn, either as a stream or as a finished message
        if not _kani_is_stream:
            completion = await self.get_model_completion(**kwargs)
            message = await self.add_completion_to_history(completion)
            yield message
        else:
            manager = StreamManager(
                self.get_model_stream(**kwargs), role=ChatRole.ASSISTANT, after=self.add_completion_to_history
            )
            yield manager
            message = await manager.message()

        # no tool calls means the round is over
        if not message.tool_calls:
            return

        # run every requested tool call concurrently, then process the results in order
        model_should_respond = False
        retry_requested = False
        error_count = 0
        tool_results = await asyncio.gather(*(self._do_tool_call(tc, retry) for tc in message.tool_calls))
        for tool_result in tool_results:
            result_msg = tool_result.message
            # record the result and forward it to the server-side conversation
            await self.add_to_history(result_msg)
            for conv_item in interop.chat_message_to_conv_items(result_msg):
                await self.session.send(client_events.ConversationItemCreate(item=conv_item))

            # hand the result to the caller, wrapped in a dummy stream when streaming
            yield DummyStream(result_msg) if _kani_is_stream else result_msg

            if isinstance(tool_result, ExceptionHandleResult):
                # a failed call always gives the model another turn
                model_should_respond = True
                error_count += 1
                # retry if any function says so
                if not retry_requested:
                    retry_requested = tool_result.should_retry
            elif not model_should_respond:
                # allow the model to respond if any function says so
                model_should_respond = tool_result.is_model_turn

        # on error bump the retry counter; if no handler asked for a retry, disable functions for the next turn
        if error_count:
            retry += 1
            if not retry_requested:
                kwargs["include_functions"] = False
        else:
            retry = 0

        # once the function-round budget is used up, disable functions for the next turn
        rounds_done += 1
        if max_function_rounds is not None and rounds_done >= max_function_rounds:
            kwargs["include_functions"] = False
async def _do_tool_call(self, tc: ToolCall, retry: int):
    """Execute one tool call and return its result, with ``is_tool_call_error`` set on the result message."""
    try:
        outcome = await self.do_function_call(tc.function, tool_call_id=tc.id)
    except FunctionCallException as exc:
        # the call failed: let the exception handler build the result and mark it as an error
        outcome = await self.handle_function_call_exception(tc.function, exc, retry, tool_call_id=tc.id)
        outcome.message.is_tool_call_error = True
    else:
        # the call succeeded: only fill in the flag if the implementation has not already set it
        if outcome.message.is_tool_call_error is None:
            outcome.message.is_tool_call_error = False
    return outcome
async def close(self):
    """Close the underlying realtime session, disconnecting from the WS."""
    session = self.session
    await session.close()
# ===== full duplex =====
[docs]
async def full_duplex(
self,
audio_stream: AsyncIterable[bytes], # todo what about manual response creates
audio_callback: Callable[[bytes], Awaitable] = None,
**kwargs, # todo this might be a good place for session config too?
) -> AsyncIterable[StreamManager]:
"""
Stream audio bytes from the given stream to the realtime model.
Yields a stream for each conversation item created (both USER and ASSISTANT). Each stream will be related to
exactly one conversation item (i.e., message), and multiple streams may emit simultaneously.
To consume tokens from a stream, use this class as so:
.. code-blocK:: python
stream_tasks = set()
async def handle_stream(stream):
# do processing for a single message's stream here...
# this example code does NOT account for multiple simultaneous messages
async for token in stream:
print(token, end="")
msg = await stream.message()
async for stream in ai.full_duplex(audio_stream):
task = asyncio.create_task(handle_stream(stream))
# to keep a live reference to the task
# see https://docs.python.org/3/library/asyncio-task.html#creating-tasks
stream_tasks.add(task)
task.add_done_callback(stream_tasks.discard)
Check out the implementation of :func:`.chat_in_terminal_audio_async` for more in-depth stream handling (e.g.,
printing out streams simultaneously without clobbering other messages' outputs).
Each :class:`.StreamManager` object yielded by this method contains a :attr:`.StreamManager.role` attribute
that can be used to determine if a message is from the user, engine or a function call. This attribute will be
available *before* iterating over the stream.
.. note::
This method will exit once the ``audio_stream`` is exhausted (i.e., the iterator raises StopAsyncIteration).
.. note::
For lower-level control over the realtime chat session (e.g. to send events directly to the server), see
:class:`.RealtimeSession` and :mod:`.events.client`. For example, you might use the following to request a
response when serverside VAD is disabled:
.. code-block:: python
from kani.ext.realtime import events
await ai.session.send(events.client.ResponseCreate())
See https://platform.openai.com/docs/api-reference/realtime-client-events for more details.
:param audio_stream: An async iterator that emits audio frames (bytes). Audio frames should be encoded as
raw 16 bit PCM audio at 24kHz, 1 channel, little-endian. See :func:`.get_audio_stream` to get such an
audio stream from a system microphone.
:param audio_callback: An async function that consumes audio frames as emitted by the model. Use
:func:`.play_audio` to play the audio from the system speaker.
"""
if audio_callback is None:
async def audio_callback(_):
pass
break_sentinel = object()
# streamer for item with given ID reads elements from their q, stored here
streamer_queues = collections.defaultdict(asyncio.Queue)
yielder_q = asyncio.Queue() # streamers to yield
# helper for yielding
async def yield_from_queue(q: asyncio.Queue):
while True:
item_to_yield = await q.get()
if item_to_yield is break_sentinel:
break
yield item_to_yield
# main event handler
async def listener(e):
"""On event from server, route the event to the right streamer or yield a new streamer"""
match e:
# ===== new conversation item =====
# we only care about messages here - function calls are handled elsewhere
case server_events.ConversationItemCreated(
item=oaimodels.MessageConversationItem(id=item_id, role=role)
):
streamer_q = streamer_queues[item_id]
await yielder_q.put(StreamManager(yield_from_queue(streamer_q), role=ChatRole(role)))
# ===== streaming items (asst) =====
case server_events.ResponseTextDelta(
item_id=item_id, delta=text
) | server_events.ResponseAudioTranscriptDelta(item_id=item_id, delta=text):
await streamer_queues[item_id].put(text)
case server_events.ResponseDone(response=response):
message = interop.response_to_chat_message(response)
completion = Completion(
message=message,
prompt_tokens=response.usage.input_tokens,
completion_tokens=response.usage.output_tokens,
)
for item_id in set(i.id for i in response.output if i.type == "message"):
q = streamer_queues[item_id]
await q.put(completion)
await q.put(break_sentinel)
streamer_queues.pop(item_id)
# ===== streaming items (user) =====
case server_events.ConversationItemInputAudioTranscriptionCompleted(item_id=item_id, transcript=text):
await streamer_queues[item_id].put(text.strip())
# emit a completion too
item = self.session.conversation_items.get(item_id)
assert isinstance(item, oaimodels.MessageConversationItem)
role = ChatRole(item.role)
content = list(map(interop.content_part_to_message_part, item.content))
message = ChatMessage(role=role, content=content)
completion = Completion(message=message, prompt_tokens=0, completion_tokens=0)
await streamer_queues[item_id].put(completion)
await streamer_queues[item_id].put(break_sentinel)
streamer_queues.pop(item_id)
# ===== audio =====
case server_events.ResponseAudioDelta(delta=audio_b64):
await audio_callback(base64.b64decode(audio_b64))
# ===== function calling =====
case server_events.ResponseOutputItemDone(
item=oaimodels.FunctionCallConversationItem(
status="completed", call_id=call_id, name=name, arguments=args
)
):
tc = ToolCall.from_function_call(FunctionCall(name=name, arguments=args), call_id)
# emit a dummystream with the function call
tc_message = ChatMessage.assistant(content=None, tool_calls=[tc])
await yielder_q.put(DummyStream(tc_message))
# actually call it and req a new completion with data
result = await self._do_tool_call(tc, 0)
# save the result to the chat history
await self.add_to_history(result.message)
for item in interop.chat_message_to_conv_items(result.message):
await self.session.send(client_events.ConversationItemCreate(item=item))
await yielder_q.put(DummyStream(result.message))
# request a new completion
await self.session.send(client_events.ResponseCreate())
# audio sender
async def audio_sender_task():
async for frame in audio_stream:
data = base64.b64encode(frame).decode()
await self.session.send(client_events.InputAudioBufferAppend(audio=data))
# when we are out of audio, tell the outer loop to break
await yielder_q.put(break_sentinel)
# add the listener, start the task to fwd audio frames, and start emitting
self.session.add_listener(listener)
audio_task = asyncio.create_task(audio_sender_task())
try:
async for stream_manager in yield_from_queue(yielder_q):
yield stream_manager
finally:
audio_task.cancel()
self.session.remove_listener(listener)