openai_gpt

langroid/language_models/openai_gpt.py

AnthropicModel

Bases: str, Enum

Enum for Anthropic models

OpenAIChatModel

Bases: str, Enum

Enum for OpenAI Chat models

GeminiModel

Bases: str, Enum

Enum for Gemini models

OpenAICompletionModel

Bases: str, Enum

Enum for OpenAI Completion models

OpenAICallParams

Bases: BaseModel

Various params that can be sent to an OpenAI API chat-completion call. When specified, any param here overrides the one with the same name in the OpenAIGPTConfig. See the OpenAI API Reference for details on the params: https://platform.openai.com/docs/api-reference/chat
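For example, a minimal sketch (assuming OpenAICallParams exposes standard OpenAI chat params such as temperature and max_tokens, and that OpenAIGPTConfig accepts it via a params field; check the class definitions for the exact names):

```python
from langroid.language_models.openai_gpt import OpenAICallParams, OpenAIGPTConfig

# Hedged sketch: per-call params that override the same-named config settings.
# The field names (temperature, max_tokens, params) are assumptions.
call_params = OpenAICallParams(temperature=0.2, max_tokens=512)
config = OpenAIGPTConfig(params=call_params)
```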

OpenAIGPTConfig(**kwargs)

Bases: LLMConfig

Class for any LLM with an OpenAI-like API. Besides the OpenAI models, this includes: (a) locally-served models behind an OpenAI-compatible API, and (b) non-local models accessed via a proxy adaptor library such as litellm that provides an OpenAI-compatible API. We could rename this class to OpenAILikeConfig.

Source code in langroid/language_models/openai_gpt.py
def __init__(self, **kwargs) -> None:  # type: ignore
    local_model = "api_base" in kwargs and kwargs["api_base"] is not None

    chat_model = kwargs.get("chat_model", "")
    local_prefixes = ["local/", "litellm/", "ollama/"]
    if any(chat_model.startswith(prefix) for prefix in local_prefixes):
        local_model = True

    warn_gpt_3_5 = (
        "chat_model" not in kwargs.keys()
        and not local_model
        and defaultOpenAIChatModel == OpenAIChatModel.GPT3_5_TURBO
    )

    if warn_gpt_3_5:
        existing_hook = kwargs.get("run_on_first_use", noop)

        def with_warning() -> None:
            existing_hook()
            gpt_3_5_warning()

        kwargs["run_on_first_use"] = with_warning

    super().__init__(**kwargs)
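
A minimal configuration sketch, based on the prefix handling in the __init__ above (the endpoint address and model names are illustrative):

```python
from langroid.language_models.openai_gpt import OpenAIGPTConfig

# A model served locally behind an OpenAI-compatible API: the "local/" prefix
# marks it as a local model, so the GPT-3.5 default-model warning is skipped.
local_config = OpenAIGPTConfig(chat_model="local/localhost:8000/v1")

# A model proxied through litellm; the "litellm/" prefix has the same effect.
litellm_config = OpenAIGPTConfig(chat_model="litellm/ollama/mistral")
```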

create(prefix) classmethod

Create a config class whose params can be set via a desired prefix from the .env file or env vars. E.g., using

OllamaConfig = OpenAIGPTConfig.create("ollama")
ollama_config = OllamaConfig()
you can have a group of params prefixed by "OLLAMA_", to be used with models served via ollama. This way, you can maintain several setting-groups in your .env file, one per model type.

Source code in langroid/language_models/openai_gpt.py
@classmethod
def create(cls, prefix: str) -> Type["OpenAIGPTConfig"]:
    """Create a config class whose params can be set via a desired
    prefix from the .env file or env vars.
    E.g., using
    ```python
    OllamaConfig = OpenAIGPTConfig.create("ollama")
    ollama_config = OllamaConfig()
    ```
    you can have a group of params prefixed by "OLLAMA_", to be used
    with models served via `ollama`.
    This way, you can maintain several setting-groups in your .env file,
    one per model type.
    """

    class DynamicConfig(OpenAIGPTConfig):
        pass

    DynamicConfig.Config.env_prefix = prefix.upper() + "_"

    return DynamicConfig
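
A hedged usage sketch: once the class is created, env vars carrying the chosen prefix populate its fields via Pydantic's BaseSettings (the exact field-to-variable mapping is assumed):

```python
import os
from langroid.language_models.openai_gpt import OpenAIGPTConfig

OllamaConfig = OpenAIGPTConfig.create("ollama")

# Illustrative: with env_prefix "OLLAMA_", a setting such as chat_model
# can be supplied as OLLAMA_CHAT_MODEL in the environment or .env file.
os.environ["OLLAMA_CHAT_MODEL"] = "ollama/mistral"
ollama_config = OllamaConfig()
```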

OpenAIResponse

Bases: BaseModel

OpenAI response model, either completion or chat.

OpenAIGPT(config=OpenAIGPTConfig())

Bases: LanguageModel

Class for OpenAI LLMs

Source code in langroid/language_models/openai_gpt.py
def __init__(self, config: OpenAIGPTConfig = OpenAIGPTConfig()):
    """
    Args:
        config: configuration for openai-gpt model
    """
    # copy the config to avoid modifying the original
    config = config.copy()
    super().__init__(config)
    self.config: OpenAIGPTConfig = config

    # Run the first time the model is used
    self.run_on_first_use = cache(self.config.run_on_first_use)

    # global override of chat_model,
    # to allow quick testing with other models
    if settings.chat_model != "":
        self.config.chat_model = settings.chat_model
        self.config.completion_model = settings.chat_model

    if len(parts := self.config.chat_model.split("//")) > 1:
        # there is a formatter specified, e.g.
        # "litellm/ollama/mistral//hf" or
        # "local/localhost:8000/v1//mistral-instruct-v0.2"
        formatter = parts[1]
        self.config.chat_model = parts[0]
        if formatter == "hf":
            # e.g. "litellm/ollama/mistral//hf" -> "litellm/ollama/mistral"
            formatter = find_hf_formatter(self.config.chat_model)
            if formatter != "":
                # e.g. "mistral"
                self.config.formatter = formatter
                logging.warning(
                    f"""
                    Using completions (not chat) endpoint with HuggingFace 
                    chat_template for {formatter} for 
                    model {self.config.chat_model}
                    """
                )
        else:
            # e.g. "local/localhost:8000/v1//mistral-instruct-v0.2"
            self.config.formatter = formatter

    if self.config.formatter is not None:
        self.config.hf_formatter = HFFormatter(
            HFPromptFormatterConfig(model_name=self.config.formatter)
        )

    # if model name starts with "litellm",
    # set the actual model name by stripping the "litellm/" prefix
    # and set the litellm flag to True
    if self.config.chat_model.startswith("litellm/") or self.config.litellm:
        # e.g. litellm/ollama/mistral
        self.config.litellm = True
        self.api_base = self.config.api_base
        if self.config.chat_model.startswith("litellm/"):
            # strip the "litellm/" prefix
            # e.g. litellm/ollama/llama2 => ollama/llama2
            self.config.chat_model = self.config.chat_model.split("/", 1)[1]
    elif self.config.chat_model.startswith("local/"):
        # expect this to be of the form "local/localhost:8000/v1",
        # depending on how the model is launched locally.
        # In this case the model served locally behind an OpenAI-compatible API
        # so we can just use `openai.*` methods directly,
        # and don't need an adaptor library like litellm
        self.config.litellm = False
        self.config.seed = None  # some models raise an error when seed is set
        # Extract the api_base from the model name after the "local/" prefix
        self.api_base = self.config.chat_model.split("/", 1)[1]
        if not self.api_base.startswith("http"):
            self.api_base = "http://" + self.api_base
    elif self.config.chat_model.startswith("ollama/"):
        self.config.ollama = True

        # use api_base from config if set, else fall back on OLLAMA_BASE_URL
        self.api_base = self.config.api_base or OLLAMA_BASE_URL
        self.api_key = OLLAMA_API_KEY
        self.config.chat_model = self.config.chat_model.replace("ollama/", "")
    else:
        self.api_base = self.config.api_base

    if settings.chat_model != "":
        # if we're overriding chat model globally, set completion model to same
        self.config.completion_model = self.config.chat_model

    if self.config.formatter is not None:
        # we want to format chats -> completions using this specific formatter
        self.config.use_completion_for_chat = True
        self.config.completion_model = self.config.chat_model

    if self.config.use_completion_for_chat:
        self.config.use_chat_for_completion = False

    # NOTE: The api_key should be set in the .env file, or via
    # an explicit `export OPENAI_API_KEY=xxx` or `setenv OPENAI_API_KEY xxx`
    # Pydantic's BaseSettings will automatically pick it up from the
    # .env file
    # The config.api_key is ignored when not using an OpenAI model
    if self.is_openai_completion_model() or self.is_openai_chat_model():
        self.api_key = config.api_key
        if self.api_key == DUMMY_API_KEY:
            self.api_key = os.getenv("OPENAI_API_KEY", DUMMY_API_KEY)
    else:
        self.api_key = DUMMY_API_KEY

    self.is_groq = self.config.chat_model.startswith("groq/")
    self.is_cerebras = self.config.chat_model.startswith("cerebras/")
    self.is_gemini = self.config.chat_model.startswith("gemini/")

    if self.is_groq:
        self.config.chat_model = self.config.chat_model.replace("groq/", "")
        self.api_key = os.getenv("GROQ_API_KEY", DUMMY_API_KEY)
        self.client = Groq(
            api_key=self.api_key,
        )
        self.async_client = AsyncGroq(
            api_key=self.api_key,
        )
    elif self.is_cerebras:
        self.config.chat_model = self.config.chat_model.replace("cerebras/", "")
        self.api_key = os.getenv("CEREBRAS_API_KEY", DUMMY_API_KEY)
        self.client = Cerebras(
            api_key=self.api_key,
        )
        # TODO there is no async client, so should we do anything here?
        self.async_client = AsyncCerebras(
            api_key=self.api_key,
        )
    else:
        if self.is_gemini:
            self.config.chat_model = self.config.chat_model.replace("gemini/", "")
            self.api_key = os.getenv("GEMINI_API_KEY", DUMMY_API_KEY)
            self.api_base = GEMINI_BASE_URL

        self.client = OpenAI(
            api_key=self.api_key,
            base_url=self.api_base,
            organization=self.config.organization,
            timeout=Timeout(self.config.timeout),
        )
        self.async_client = AsyncOpenAI(
            api_key=self.api_key,
            organization=self.config.organization,
            base_url=self.api_base,
            timeout=Timeout(self.config.timeout),
        )

    self.cache: CacheDB | None = None
    use_cache = self.config.cache_config is not None
    if settings.cache_type == "momento" and use_cache:
        from langroid.cachedb.momento_cachedb import (
            MomentoCache,
            MomentoCacheConfig,
        )

        if config.cache_config is None or not isinstance(
            config.cache_config,
            MomentoCacheConfig,
        ):
            # switch to fresh momento config if needed
            config.cache_config = MomentoCacheConfig()
        self.cache = MomentoCache(config.cache_config)
    elif "redis" in settings.cache_type and use_cache:
        if config.cache_config is None or not isinstance(
            config.cache_config,
            RedisCacheConfig,
        ):
            # switch to fresh redis config if needed
            config.cache_config = RedisCacheConfig(
                fake="fake" in settings.cache_type
            )
        if "fake" in settings.cache_type:
            # force use of fake redis if global cache_type is "fakeredis"
            config.cache_config.fake = True
        self.cache = RedisCache(config.cache_config)
    elif settings.cache_type != "none" and use_cache:
        raise ValueError(
            f"Invalid cache type {settings.cache_type}. "
            "Valid types are momento, redis, fakeredis, none"
        )

    self.config._validate_litellm()
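
A minimal usage sketch (the model name is illustrative, and an OPENAI_API_KEY is assumed to be set in the environment or .env file):

```python
from langroid.language_models.openai_gpt import OpenAIGPT, OpenAIGPTConfig

# The constructor above copies the config, resolves api_base/api_key from the
# chat_model prefix, builds the sync/async clients, and sets up caching.
llm = OpenAIGPT(OpenAIGPTConfig(chat_model="gpt-4o"))
```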

unsupported_params()

List of params that are not supported by the current model

Source code in langroid/language_models/openai_gpt.py
def unsupported_params(self) -> List[str]:
    """
    List of params that are not supported by the current model
    """
    match self.config.chat_model:
        case OpenAIChatModel.O1_MINI | OpenAIChatModel.O1_PREVIEW:
            return ["temperature", "stream"]
        case _:
            return []
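
A sketch of how a caller could drop unsupported params before building a request (continuing with the llm object from the earlier sketch; this is not the library's actual call path):

```python
# Remove params the current model rejects (e.g. temperature/stream for o1-mini).
api_kwargs = {"temperature": 0.7, "stream": True, "max_tokens": 256}
for param in llm.unsupported_params():
    api_kwargs.pop(param, None)
```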

rename_params()

Map of param name -> new name for specific models. Currently the main troublemaker is the o1* series.

Source code in langroid/language_models/openai_gpt.py
def rename_params(self) -> Dict[str, str]:
    """
    Map of param name -> new name for specific models.
    Currently main troublemaker is o1* series.
    """
    match self.config.chat_model:
        case (
            OpenAIChatModel.O1_MINI
            | OpenAIChatModel.O1_PREVIEW
            | GeminiModel.GEMINI_1_5_FLASH
            | GeminiModel.GEMINI_1_5_FLASH_8B
            | GeminiModel.GEMINI_1_5_PRO
        ):
            return {"max_tokens": "max_completion_tokens"}
        case _:
            return {}
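
Continuing the previous sketch, the rename map can be applied to the same kwargs dict before sending the request:

```python
# e.g. for o1-series models, max_tokens is renamed to max_completion_tokens.
renames = llm.rename_params()
api_kwargs = {renames.get(k, k): v for k, v in api_kwargs.items()}
```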

chat_context_length()

Context length for chat-completion models/endpoints. Get it from the dict; otherwise fall back to the general method.

Source code in langroid/language_models/openai_gpt.py
def chat_context_length(self) -> int:
    """
    Context-length for chat-completion models/endpoints
    Get it from the dict, otherwise fail-over to general method
    """
    model = (
        self.config.completion_model
        if self.config.use_completion_for_chat
        else self.config.chat_model
    )
    return _context_length.get(model, super().chat_context_length())

completion_context_length()

Context length for completion models/endpoints. Get it from the dict; otherwise fall back to the general method.

Source code in langroid/language_models/openai_gpt.py
def completion_context_length(self) -> int:
    """
    Context-length for completion models/endpoints
    Get it from the dict, otherwise fail-over to general method
    """
    model = (
        self.config.chat_model
        if self.config.use_chat_for_completion
        else self.config.completion_model
    )
    return _context_length.get(model, super().completion_context_length())

chat_cost()

(Prompt, Generation) cost per 1000 tokens, for chat-completion models/endpoints. Get it from the dict; otherwise fall back to the general method.

Source code in langroid/language_models/openai_gpt.py
def chat_cost(self) -> Tuple[float, float]:
    """
    (Prompt, Generation) cost per 1000 tokens, for chat-completion
    models/endpoints.
    Get it from the dict, otherwise fail-over to general method
    """
    return _cost_per_1k_tokens.get(self.config.chat_model, super().chat_cost())
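
A sketch of estimating the cost of a single call from the returned pair (token counts are illustrative):

```python
# chat_cost() returns (prompt_cost, generation_cost) per 1000 tokens.
prompt_cost, gen_cost = llm.chat_cost()
prompt_tokens, completion_tokens = 1200, 350  # illustrative counts
estimated_cost = (
    prompt_tokens / 1000 * prompt_cost + completion_tokens / 1000 * gen_cost
)
```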

set_stream(stream)

Enable or disable streaming output from API.

Parameters:

    stream: enable streaming output from API

Returns:

    previous value of stream

Source code in langroid/language_models/openai_gpt.py
def set_stream(self, stream: bool) -> bool:
    """Enable or disable streaming output from API.
    Args:
        stream: enable streaming output from API
    Returns: previous value of stream
    """
    tmp = self.config.stream
    self.config.stream = stream
    return tmp
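
A usage sketch: temporarily disable streaming, then restore the previous setting:

```python
prev = llm.set_stream(False)
try:
    ...  # make a non-streaming call here
finally:
    llm.set_stream(prev)  # set_stream returns the prior value, so it can be restored
```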

get_stream()

Get streaming status. Note we disable streaming in quiet mode.

Source code in langroid/language_models/openai_gpt.py
def get_stream(self) -> bool:
    """Get streaming status. Note we disable streaming in quiet mode."""
    return (
        self.config.stream
        and settings.stream
        and self.config.chat_model not in NON_STREAMING_MODELS
        and not settings.quiet
    )

tool_deltas_to_tools(tools) staticmethod

Convert accumulated tool-call deltas to OpenAIToolCall objects. Adapted from this excellent code: https://community.openai.com/t/help-for-function-calls-with-streaming/627170/2

Parameters:

    tools (List[Dict[str, Any]], required): list of tool deltas received from streaming API

Returns:

    str: plain text corresponding to tool calls that failed to parse
    List[OpenAIToolCall]: list of OpenAIToolCall objects
    List[Dict[str, Any]]: list of tool dicts (to reconstruct OpenAI API response, so it can be cached)

Source code in langroid/language_models/openai_gpt.py
@staticmethod
def tool_deltas_to_tools(tools: List[Dict[str, Any]]) -> Tuple[
    str,
    List[OpenAIToolCall],
    List[Dict[str, Any]],
]:
    """
    Convert accumulated tool-call deltas to OpenAIToolCall objects.
    Adapted from this excellent code:
     https://community.openai.com/t/help-for-function-calls-with-streaming/627170/2

    Args:
        tools: list of tool deltas received from streaming API

    Returns:
        str: plain text corresponding to tool calls that failed to parse
        List[OpenAIToolCall]: list of OpenAIToolCall objects
        List[Dict[str, Any]]: list of tool dicts
            (to reconstruct OpenAI API response, so it can be cached)
    """
    # Initialize a dictionary with default values

    # idx -> dict repr of tool
    # (used to simulate OpenAIResponse object later, and also to
    # accumulate function args as strings)
    idx2tool_dict: Dict[str, Dict[str, Any]] = defaultdict(
        lambda: {
            "id": None,
            "function": {"arguments": "", "name": None},
            "type": None,
        }
    )

    for tool_delta in tools:
        if tool_delta["id"] is not None:
            idx2tool_dict[tool_delta["index"]]["id"] = tool_delta["id"]

        if tool_delta["function"]["name"] is not None:
            idx2tool_dict[tool_delta["index"]]["function"]["name"] = tool_delta[
                "function"
            ]["name"]

        idx2tool_dict[tool_delta["index"]]["function"]["arguments"] += tool_delta[
            "function"
        ]["arguments"]

        if tool_delta["type"] is not None:
            idx2tool_dict[tool_delta["index"]]["type"] = tool_delta["type"]

    # (try to) parse the fn args of each tool
    contents: List[str] = []
    good_indices = []
    id2args: Dict[str, None | Dict[str, Any]] = {}
    for idx, tool_dict in idx2tool_dict.items():
        failed_content, args_dict = OpenAIGPT._parse_function_args(
            tool_dict["function"]["arguments"]
        )
        # used to build tool_calls_list below
        id2args[tool_dict["id"]] = args_dict or None  # if {}, store as None
        if failed_content != "":
            contents.append(failed_content)
        else:
            good_indices.append(idx)

    # remove the failed tool calls
    idx2tool_dict = {
        idx: tool_dict
        for idx, tool_dict in idx2tool_dict.items()
        if idx in good_indices
    }

    # create OpenAIToolCall list
    tool_calls_list = [
        OpenAIToolCall(
            id=tool_dict["id"],
            function=LLMFunctionCall(
                name=tool_dict["function"]["name"],
                arguments=id2args.get(tool_dict["id"]),
            ),
            type=tool_dict["type"],
        )
        for tool_dict in idx2tool_dict.values()
    ]
    return "\n".join(contents), tool_calls_list, list(idx2tool_dict.values())

noop()

Does nothing.

Source code in langroid/language_models/openai_gpt.py
def noop() -> None:
    """Does nothing."""
    return None

litellm_logging_fn(model_call_dict)

Logging function for litellm

Source code in langroid/language_models/openai_gpt.py
def litellm_logging_fn(model_call_dict: Dict[str, Any]) -> None:
    """Logging function for litellm"""
    try:
        api_input_dict = model_call_dict.get("additional_args", {}).get(
            "complete_input_dict"
        )
        if api_input_dict is not None:
            text = escape(json.dumps(api_input_dict, indent=2))
            print(
                f"[grey37]LITELLM: {text}[/grey37]",
            )
    except Exception:
        pass