Skip to content

sql_chat_agent

langroid/agent/special/sql/sql_chat_agent.py

Agent that allows interaction with an SQL database using the SQLAlchemy library. The agent can execute SQL queries in the database and return the results.

Functionality includes:

- adding table and column context
- asking a question about a SQL schema

SQLChatAgentConfig

Bases: ChatAgentConfig

addressing_prefix: str = SEND_TO (class attribute, instance attribute)

context_descriptions: Optional, but strongly recommended, context descriptions for tables, columns, and relationships. It should be a dictionary where each key is a table name and its value is another dictionary.

In this inner dictionary:

- The 'description' key corresponds to a string description of the table.
- The 'columns' key corresponds to another dictionary where each key is a column name and its value is a string description of that column.
- The 'relationships' key corresponds to another dictionary where each key is another table name and the value is a description of the relationship to that table.

If `multi_schema` support is enabled, the table names in the descriptions should be of the form 'schema_name.table_name'.

For example:

```python
{
    'table1': {
        'description': 'description of table1',
        'columns': {
            'column1': 'description of column1 in table1',
            'column2': 'description of column2 in table1'
        }
    },
    'table2': {
        'description': 'description of table2',
        'columns': {
            'column3': 'description of column3 in table2',
            'column4': 'description of column4 in table2'
        }
    }
}
```

SQLChatAgent(config)

Bases: ChatAgent

Agent for chatting with a SQL database

Raises:

- ValueError: If database information is not provided in the config.

Source code in langroid/agent/special/sql/sql_chat_agent.py
def __init__(self, config: "SQLChatAgentConfig") -> None:
    """Initialize the SQLChatAgent.

    Checks the config with `_validate_config`, stores it on the instance, and
    then runs the remaining initialization steps in this fixed order:
    database, metadata, table metadata, message tools.

    Args:
        config (SQLChatAgentConfig): Configuration for this agent.

    Raises:
        ValueError: If database information is not provided in the config.
    """
    # Validation runs before the config is stored on the instance.
    self._validate_config(config)
    self.config: SQLChatAgentConfig = config
    # NOTE(review): these steps look order-dependent (later steps presumably
    # rely on state set by earlier ones) -- confirm before reordering them.
    self._init_database()
    self._init_metadata()
    self._init_table_metadata()
    self._init_message_tools()

handle_message_fallback(msg)

Handle the scenario where the current msg is not a tool. Special handling is only needed if the message was from the LLM (as indicated by self.llm_responded).

Source code in langroid/agent/special/sql/sql_chat_agent.py
def handle_message_fallback(
    self, msg: str | ChatDocument
) -> str | ChatDocument | None:
    """
    Fallback handler for a message that is not a tool call.

    Nothing is done unless the message came from the LLM
    (``self.llm_responded``). For LLM messages:

    - If ``run_query`` has already been used, the message is treated as the
      final answer. It is returned prefixed with ``DONE`` and, when
      ``addressing_prefix`` is set, with ``<addressing_prefix>User``.
    - Otherwise a reminder string is returned, nudging the LLM to run a query.
      When ``addressing_prefix`` is non-empty, the reminder also mentions
      addressing the user with that prefix.

    Args:
        msg (str | ChatDocument): The non-tool message.

    Returns:
        The final-answer or reminder string, or None when the message was
        not produced by the LLM.
    """
    if not self.llm_responded:
        # Only LLM-produced messages need special handling.
        return None

    if not self.used_run_query:
        # The LLM replied without running a query: remind it what to do.
        reminder = """
        You may have forgotten to use the `run_query` tool to execute an SQL query
        for the user's question/request            
        """
        if self.config.addressing_prefix != "":
            reminder += f"""
            OR you may have forgotten to address the user using the prefix
            {self.config.addressing_prefix} 
            """
        return reminder

    # A query has already been run, so this message is the final answer.
    prefix_setting = self.config.addressing_prefix
    recipient = prefix_setting + "User" if prefix_setting else ""
    body = msg.content if isinstance(msg, ChatDocument) else msg
    return DONE + recipient + body

retry_query(e, query)

Generate an error message for a failed SQL query and return it.

Parameters:

- e (Exception): The exception raised during the SQL query execution.
- query (str): The SQL query that failed.

Returns:

- str: The error message.

Source code in langroid/agent/special/sql/sql_chat_agent.py
def retry_query(self, e: Exception, query: str) -> str:
    """
    Generate an error message for a failed SQL query and return it.

    The message quotes the failed query and the exception text, and asks for
    a corrected query. When ``self.config.use_schema_tools`` is false, the
    configured ``context_descriptions`` are appended inside a ```json block.
    They are serialized as real JSON so that the content matches its label.

    Parameters:
    e (Exception): The exception raised during the SQL query execution.
    query (str): The SQL query that failed.

    Returns:
    str: The error message.
    """
    import json

    logger.error(f"SQL Query failed: {query}\nException: {e}")

    # Optional part to be included based on `use_schema_tools`
    optional_schema_description = ""
    if not self.config.use_schema_tools:
        descriptions = self.config.context_descriptions
        try:
            # The prompt labels this as JSON, so emit JSON (double quotes,
            # null/true/false) instead of a Python dict repr. Keep non-ASCII
            # descriptions readable rather than \u-escaped.
            schema_text = json.dumps(descriptions, ensure_ascii=False)
        except (TypeError, ValueError):
            # Values JSON cannot encode: fall back to the plain repr.
            schema_text = str(descriptions)
        optional_schema_description = f"""\
        This JSON schema maps SQL database structure. It outlines tables, each 
        with a description and columns. Each table is identified by a key, and holds
        a description and a dictionary of columns, with column 
        names as keys and their descriptions as values.

        ```json
        {schema_text}
        ```"""

    # Construct the error message
    error_message_template = f"""\
    {SQL_ERROR_MSG}: '{query}'
    {str(e)}
    Run a new query, correcting the errors.
    {optional_schema_description}"""

    return error_message_template

run_query(msg)

Handle a RunQueryTool message by executing a SQL query and returning the result.

Parameters:

- msg (RunQueryTool): The tool-message to handle. Required.

Returns:

- str: The result of executing the SQL query.

Source code in langroid/agent/special/sql/sql_chat_agent.py
def run_query(self, msg: RunQueryTool) -> str:
    """
    Handle a RunQueryTool message by executing a SQL query and returning the result.

    Row-returning queries are formatted with ``self._format_rows``. For
    statements without a result set (UPDATE, INSERT, DELETE, ...) the affected
    row count is reported instead. The result is consumed *before* the session
    is committed, so it is read while the transaction and its connection are
    still active.

    On a SQLAlchemy error the session is rolled back, and a retry prompt built
    by ``self.retry_query`` is returned. The session is closed in all cases,
    and ``self.used_run_query`` is set to True.

    Args:
        msg (RunQueryTool): The tool-message to handle.

    Returns:
        str: The result of executing the SQL query.
    """
    query = msg.query
    session = self.Session
    self.used_run_query = True
    try:
        logger.info(f"Executing SQL query: {query}")

        query_result = session.execute(text(query))
        try:
            # attempt to fetch results: should work for normal SELECT queries
            rows = query_result.fetchall()
            response_message = self._format_rows(rows)
        except ResourceClosedError:
            # If we get here, it's a non-SELECT query (UPDATE, INSERT, DELETE)
            affected_rows = query_result.rowcount  # type: ignore
            response_message = f"""
                Non-SELECT query executed successfully. 
                Rows affected: {affected_rows}
                """
        # Commit only after the result has been consumed: committing ends the
        # transaction and can release the connection the cursor depends on.
        session.commit()

    except SQLAlchemyError as e:
        session.rollback()
        logger.error(f"Failed to execute query: {query}\n{e}")
        response_message = self.retry_query(e, query)
    finally:
        session.close()

    return response_message

get_table_names(msg)

Handle a GetTableNamesTool message by returning the names of all tables in the database.

Returns:

- str: The names of all tables in the database.

Source code in langroid/agent/special/sql/sql_chat_agent.py
def get_table_names(self, msg: GetTableNamesTool) -> str:
    """
    Handle a GetTableNamesTool message by returning the names of all tables in the
    database.

    ``self.metadata`` is either a single metadata object or a list of them
    (presumably one per schema in multi-schema mode -- confirm). In both cases
    all table names are returned as one comma-separated string.

    Args:
        msg (GetTableNamesTool): The tool-message; its contents are not used.

    Returns:
        str: The names of all tables in the database.
    """
    if isinstance(self.metadata, list):
        # Flatten across all metadata objects before joining, so a metadata
        # object with no tables does not leave an empty ", , " gap.
        return ", ".join(
            name for md in self.metadata for name in md.tables.keys()
        )

    return ", ".join(self.metadata.tables.keys())

get_table_schema(msg)

Handle a GetTableSchemaTool message by returning the schema of all provided tables in the database.

Returns:

- str: The schema of all provided tables in the database.

Source code in langroid/agent/special/sql/sql_chat_agent.py
def get_table_schema(self, msg: GetTableSchemaTool) -> str:
    """
    Handle a GetTableSchemaTool message by returning the schema of all provided
    tables in the database.

    Each requested table yields one newline-terminated line. Known tables give
    ``"<name>: <schema>"``, using the entry from ``self.table_metadata``.
    Unknown names give ``"<name> is not a valid table name."``.

    Args:
        msg (GetTableSchemaTool): The tool-message listing the tables.

    Returns:
        str: The schema of all provided tables in the database.
    """
    lines = []
    for name in msg.tables:
        schema = self.table_metadata.get(name)
        if schema is None:
            lines.append(f"{name} is not a valid table name.\n")
        else:
            lines.append(f"{name}: {schema}\n")
    return "".join(lines)

get_column_descriptions(msg)

Handle a GetColumnDescriptionsTool message by returning the descriptions of all provided columns from the database.

Returns:

- str: The descriptions of all provided columns from the database.

Source code in langroid/agent/special/sql/sql_chat_agent.py
def get_column_descriptions(self, msg: GetColumnDescriptionsTool) -> str:
    """
    Handle a GetColumnDescriptionsTool message by returning the descriptions of all
    provided columns from the database.

    Descriptions are looked up in ``self.config.context_descriptions`` under
    ``[table]["columns"][column]``. ``msg.columns`` is a comma-separated list of
    column names; whitespace around each name is ignored, so ``"a, b"`` and
    ``"a,b"`` are equivalent. An unknown table or column yields a placeholder
    description instead of raising, so the caller still gets a usable reply.

    Args:
        msg (GetColumnDescriptionsTool): The tool-message naming the table and
            the columns to describe.

    Returns:
        str: The descriptions of all provided columns from the database. The
        output starts with a ``"\\nTABLE: <table>"`` header, followed by one
        ``"\\n<column> => <description>"`` line per column.
    """
    missing = "(no description available)"
    table = msg.table
    # Tolerate "a, b", "a,b", "a ,b", ...; skip empty entries.
    columns = [col.strip() for col in msg.columns.split(",") if col.strip()]

    # The table may have no descriptions at all (.get returns None), or an
    # entry without a usable "columns" mapping: treat both as "no descriptions".
    table_info = self.config.context_descriptions.get(table)
    column_info = (
        table_info.get("columns") if isinstance(table_info, dict) else None
    )
    if not isinstance(column_info, dict):
        column_info = {}

    parts = [f"\nTABLE: {table}"]
    parts.extend(f"\n{col} => {column_info.get(col, missing)}" for col in columns)
    return "".join(parts)