
openai_assistant

langroid/agent/openai_assistant.py

OpenAIAssistant(config)

Bases: ChatAgent

A ChatAgent powered by the OpenAI Assistant API: mainly, in the llm_response method, we avoid maintaining conversation state ourselves and instead let the Assistant API do it for us. Also handles persistent storage of Assistants and Threads: their ids (for a given user and org) are stored in a cache and reused according to config.use_cached_assistant and config.use_cached_thread.

This class can be used as a drop-in replacement for ChatAgent.

Source code in langroid/agent/openai_assistant.py
def __init__(self, config: OpenAIAssistantConfig):
    super().__init__(config)
    self.config: OpenAIAssistantConfig = config
    self.llm: OpenAIGPT = OpenAIGPT(self.config.llm)
    if not isinstance(self.llm.client, openai.OpenAI):
        raise ValueError("Client must be OpenAI")
    # handles for various entities and methods
    self.client: openai.OpenAI = self.llm.client
    self.runs = self.client.beta.threads.runs
    self.threads = self.client.beta.threads
    self.thread_messages = self.client.beta.threads.messages
    self.assistants = self.client.beta.assistants
    # which tool_ids are awaiting output submissions
    self.pending_tool_ids: List[str] = []
    self.cached_tool_ids: List[str] = []

    self.thread: Thread | None = None
    self.assistant: Assistant | None = None
    self.run: Run | None = None

    self._maybe_create_assistant(self.config.assistant_id)
    self._maybe_create_thread(self.config.thread_id)
    self._cache_store()

    self.add_assistant_files(self.config.files)
    self.add_assistant_tools(self.config.tools)
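For orientation, a minimal construction sketch. The use_cached_assistant / use_cached_thread flags are the config fields described above; the name field is an assumption, carried over from ChatAgentConfig.

from langroid.agent.openai_assistant import OpenAIAssistant, OpenAIAssistantConfig

config = OpenAIAssistantConfig(
    name="DocHelper",  # assumed inherited from ChatAgentConfig
    use_cached_assistant=True,  # reuse a previously stored Assistant id
    use_cached_thread=True,  # reuse a previously stored Thread id
)
agent = OpenAIAssistant(config)  # drop-in replacement for ChatAgent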

add_assistant_files(files)

Add file_ids to assistant

Source code in langroid/agent/openai_assistant.py
def add_assistant_files(self, files: List[str]) -> None:
    """Add file_ids to assistant"""
    if self.assistant is None:
        raise ValueError("Assistant is None")
    self.files = [
        self.client.files.create(file=open(f, "rb"), purpose="assistants")
        for f in files
    ]
    self.config.files = list(set(self.config.files + files))
    self.assistant = self.assistants.update(
        self.assistant.id,
        tool_resources=ToolResources(
            code_interpreter=ToolResourcesCodeInterpreter(
                file_ids=[f.id for f in self.files],
            ),
        ),
    )
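A usage sketch (file paths are illustrative): each local file is uploaded with purpose "assistants", and the resulting file ids are attached to the Assistant's code-interpreter tool resources.

agent.add_assistant_files(["data/sales_2023.csv", "notes/summary.txt"])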

add_assistant_tools(tools)

Add tools to assistant

Source code in langroid/agent/openai_assistant.py
def add_assistant_tools(self, tools: List[AssistantTool]) -> None:
    """Add tools to assistant"""
    if self.assistant is None:
        raise ValueError("Assistant is None")
    all_tool_dicts = [t.dct() for t in self.config.tools]
    for t in tools:
        if t.dct() not in all_tool_dicts:
            self.config.tools.append(t)
    self.assistant = self.assistants.update(
        self.assistant.id,
        tools=[tool.dct() for tool in self.config.tools],  # type: ignore
    )
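A hedged usage sketch, assuming AssistantTool can be constructed with a type field naming a built-in Assistant tool; see the source for the exact accepted values.

from langroid.agent.openai_assistant import AssistantTool

agent.add_assistant_tools([AssistantTool(type="code_interpreter")])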

enable_message(message_class, use=True, handle=True, force=False, require_recipient=False, include_defaults=True)

Override ChatAgent's method: extract the function-related args. See that method for details. Note in particular the include_defaults arg: the OpenAI completion API normally ignores these default fields, but Assistant function-calling seems to pay attention to them; if that is not desired, set include_defaults to False.

Source code in langroid/agent/openai_assistant.py
def enable_message(
    self,
    message_class: Optional[Type[ToolMessage]],
    use: bool = True,
    handle: bool = True,
    force: bool = False,
    require_recipient: bool = False,
    include_defaults: bool = True,
) -> None:
    """Override ChatAgent's method: extract the function-related args.
    See that method for details. But specifically about the `include_defaults` arg:
    Normally the OpenAI completion API ignores these fields, but Assistant
    function-calling seems to pay attention to them; if we don't want this,
    we should set this to False.
    """
    super().enable_message(
        message_class,
        use=use,
        handle=handle,
        force=force,
        require_recipient=require_recipient,
        include_defaults=include_defaults,
    )
    if message_class is None or not use:
        # no specific msg class, or
        # we are not enabling USAGE/GENERATION of this tool/fn,
        # then there's no need to attach the fn to the assistant
        # (HANDLING the fn will still work via self.agent_response)
        return
    if self.config.use_tools:
        sys_msg = self._create_system_and_tools_message()
        self.set_system_message(sys_msg.content)
    if not self.config.use_functions_api:
        return
    functions, _ = self._function_args()
    if functions is None:
        return
    # add the functions to the assistant:
    if self.assistant is None:
        raise ValueError("Assistant is None")
    tools = self.assistant.tools
    tools.extend(
        [
            {
                "type": "function",  # type: ignore
                "function": f.dict(),
            }
            for f in functions
        ]
    )
    self.assistant = self.assistants.update(
        self.assistant.id,
        tools=tools,  # type: ignore
    )
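A sketch of enabling a tool, following langroid's usual ToolMessage pattern; the tool itself is hypothetical. Per the note above, include_defaults=False keeps the defaulted fields out of the function schema sent to the Assistant.

from langroid.agent.tool_message import ToolMessage

class CityTemperature(ToolMessage):
    request: str = "city_temperature"
    purpose: str = "To find the current <temperature> in a given <city>."
    city: str

    def handle(self) -> str:
        # Stub handler; a real one would query a weather service.
        return f"The temperature in {self.city} is 21C."

agent.enable_message(CityTemperature, include_defaults=False)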

thread_msg_to_llm_msg(msg) staticmethod

Convert a Message to an LLMMessage

Source code in langroid/agent/openai_assistant.py
@staticmethod
def thread_msg_to_llm_msg(msg: Message) -> LLMMessage:
    """
    Convert a Message to an LLMMessage
    """
    return LLMMessage(
        content=msg.content[0].text.value,  # type: ignore
        role=Role(msg.role),
    )

set_system_message(msg)

Override ChatAgent's method. The Task may use this method to set the system message of the chat assistant.

Source code in langroid/agent/openai_assistant.py
def set_system_message(self, msg: str) -> None:
    """
    Override ChatAgent's method.
    The Task may use this method to set the system message
    of the chat assistant.
    """
    super().set_system_message(msg)
    if self.assistant is None:
        raise ValueError("Assistant is None")
    self.assistant = self.assistants.update(self.assistant.id, instructions=msg)
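Usage is a one-liner; the update propagates to the remote Assistant's instructions as well as the local agent state.

agent.set_system_message("You are a concise analyst; answer in bullet points.")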

process_citations(thread_msg)

Process citations in the thread message. Modifies the thread message in-place.

Source code in langroid/agent/openai_assistant.py
def process_citations(self, thread_msg: Message) -> None:
    """
    Process citations in the thread message.
    Modifies the thread message in-place.
    """
    # could there be multiple content items?
    # TODO content could be MessageContentImageFile; handle that later
    annotated_content = thread_msg.content[0].text  # type: ignore
    annotations = annotated_content.annotations
    citations = []
    # Iterate over the annotations and add footnotes
    for index, annotation in enumerate(annotations):
        # Replace the text with a footnote
        annotated_content.value = annotated_content.value.replace(
            annotation.text, f" [{index}]"
        )
        # Gather citations based on annotation attributes
        if file_citation := getattr(annotation, "file_citation", None):
            try:
                cited_file = self.client.files.retrieve(file_citation.file_id)
            except Exception:
                logger.warning(
                    f"""
                    Could not retrieve cited file with id {file_citation.file_id}, 
                    ignoring. 
                    """
                )
                continue
            citations.append(
                f"[{index}] '{file_citation.quote}',-- from {cited_file.filename}"
            )
        elif file_path := getattr(annotation, "file_path", None):
            cited_file = self.client.files.retrieve(file_path.file_id)
            citations.append(
                f"[{index}] Click <here> to download {cited_file.filename}"
            )
        # Note: File download functionality not implemented above for brevity
    sep = "\n" if len(citations) > 0 else ""
    annotated_content.value += sep + "\n".join(citations)
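A sketch of how this might be invoked, assuming agent.thread is set and the thread has at least one message; the commented values are illustrative only.

msg = agent.thread_messages.list(thread_id=agent.thread.id).data[0]
agent.process_citations(msg)
# Each annotation marker in the text is now replaced by " [i]", and lines like
#   [0] 'quoted passage',-- from report.pdf
# are appended at the end.
print(msg.content[0].text.value)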

llm_response(message=None)

Override ChatAgent's method: this is the main LLM response method. In the ChatAgent, this updates self.message_history and then calls self.llm_response_messages; but since we rely on the Assistant API to maintain conversation state, this method is simpler: it just starts a run on the message thread and waits for it to complete.

Parameters:

    message (Optional[str | ChatDocument], optional): message to respond to
        (if absent, the LLM response will be based on the instructions in the
        system_message). Defaults to None.

Returns:

    Optional[ChatDocument]: LLM response

Source code in langroid/agent/openai_assistant.py
def llm_response(
    self, message: Optional[str | ChatDocument] = None
) -> Optional[ChatDocument]:
    """
    Override ChatAgent's method: this is the main LLM response method.
    In the ChatAgent, this updates `self.message_history` and then calls
    `self.llm_response_messages`, but since we are relying on the Assistant API
    to maintain conversation state, this method is simpler: Simply start a run
    on the message-thread, and wait for it to complete.

    Args:
        message (Optional[str | ChatDocument], optional): message to respond to
            (if absent, the LLM response will be based on the
            instructions in the system_message). Defaults to None.
    Returns:
        Optional[ChatDocument]: LLM response
    """
    response = self._llm_response_preprocess(message)
    cached = True
    if response is None:
        cached = False
        response = self._run_result()
    return self._llm_response_postprocess(response, cached=cached, message=message)
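A usage sketch; because the Assistant API holds the conversation state, successive calls continue the same thread without any local message history.

reply = agent.llm_response("Summarize the uploaded report in two sentences.")
if reply is not None:
    print(reply.content)
# Follow-up: context is carried by the remote thread, not by the agent.
reply = agent.llm_response("Now list the three main caveats.")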

llm_response_async(message=None) async

Async version of llm_response.

Source code in langroid/agent/openai_assistant.py
async def llm_response_async(
    self, message: Optional[str | ChatDocument] = None
) -> Optional[ChatDocument]:
    """
    Async version of llm_response.
    """
    response = self._llm_response_preprocess(message)
    cached = True
    if response is None:
        cached = False
        response = await self._run_result_async()
    return self._llm_response_postprocess(response, cached=cached, message=message)
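Usage mirrors the sync method; a minimal sketch:

import asyncio

async def main() -> None:
    reply = await agent.llm_response_async("And what are the key risks?")
    if reply is not None:
        print(reply.content)

asyncio.run(main())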