
openai_gpt

langroid/language_models/openai_gpt.py

OpenAIChatModel

Bases: str, Enum

Enum for OpenAI Chat models

OpenAICompletionModel

Bases: str, Enum

Enum for OpenAI Completion models

OpenAICallParams

Bases: BaseModel

Various params that can be sent to an OpenAI API chat-completion call. When specified, any param here overrides the one with the same name in the OpenAIGPTConfig.
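
For example, a minimal sketch of overriding a sampling parameter per call. The field names here (temperature on OpenAICallParams, and a params field on OpenAIGPTConfig) are assumptions for illustration; check the source for the exact field names.

from langroid.language_models.openai_gpt import (
    OpenAICallParams,
    OpenAIGPTConfig,
)

# Hypothetical field names: adjust to the actual OpenAICallParams fields.
call_params = OpenAICallParams(temperature=0.2)
config = OpenAIGPTConfig(params=call_params)  # overrides the config-level setting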

OpenAIGPTConfig(**kwargs)

Bases: LLMConfig

Class for any LLM with an OpenAI-like API. Besides the OpenAI models, this includes: (a) locally served models behind an OpenAI-compatible API, and (b) non-local models accessed via a proxy adaptor library like litellm that provides an OpenAI-compatible API. We could rename this class to OpenAILikeConfig.

Source code in langroid/language_models/openai_gpt.py
def __init__(self, **kwargs) -> None:  # type: ignore
    local_model = "api_base" in kwargs and kwargs["api_base"] is not None

    chat_model = kwargs.get("chat_model", "")
    local_prefixes = ["local/", "litellm/", "ollama/"]
    if any(chat_model.startswith(prefix) for prefix in local_prefixes):
        local_model = True

    warn_gpt_3_5 = (
        "chat_model" not in kwargs.keys()
        and not local_model
        and defaultOpenAIChatModel == OpenAIChatModel.GPT3_5_TURBO
    )

    if warn_gpt_3_5:
        existing_hook = kwargs.get("run_on_first_use", noop)

        def with_warning() -> None:
            existing_hook()
            gpt_3_5_warning()

        kwargs["run_on_first_use"] = with_warning

    super().__init__(**kwargs)
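
As a sketch, configs for an OpenAI model, a locally served model, and an ollama-served model might look like the following (model names and URLs are placeholders):

from langroid.language_models.openai_gpt import OpenAIGPTConfig

# standard OpenAI model
openai_cfg = OpenAIGPTConfig(chat_model="gpt-4")

# locally served model behind an OpenAI-compatible API (hypothetical URL)
local_cfg = OpenAIGPTConfig(chat_model="local/localhost:8000/v1")

# model served via ollama
ollama_cfg = OpenAIGPTConfig(chat_model="ollama/mistral")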

create(prefix) classmethod

Create a config class whose params can be set via a desired prefix from the .env file or env vars. E.g., using

OllamaConfig = OpenAIGPTConfig.create("ollama")
ollama_config = OllamaConfig()
you can have a group of params prefixed by "OLLAMA_", to be used with models served via ollama. This way, you can maintain several setting-groups in your .env file, one per model type.

Source code in langroid/language_models/openai_gpt.py
@classmethod
def create(cls, prefix: str) -> Type["OpenAIGPTConfig"]:
    """Create a config class whose params can be set via a desired
    prefix from the .env file or env vars.
    E.g., using
    ```python
    OllamaConfig = OpenAIGPTConfig.create("ollama")
    ollama_config = OllamaConfig()
    ```
    you can have a group of params prefixed by "OLLAMA_", to be used
    with models served via `ollama`.
    This way, you can maintain several setting-groups in your .env file,
    one per model type.
    """

    class DynamicConfig(OpenAIGPTConfig):
        pass

    DynamicConfig.Config.env_prefix = prefix.upper() + "_"

    return DynamicConfig
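
A usage sketch: the env var names below are illustrative, following pydantic's env_prefix convention (field name upper-cased, with the OLLAMA_ prefix).

# In your .env (illustrative entries):
#   OLLAMA_CHAT_MODEL=ollama/mistral
#   OLLAMA_API_BASE=http://localhost:11434

OllamaConfig = OpenAIGPTConfig.create("ollama")
ollama_config = OllamaConfig()  # picks up OLLAMA_*-prefixed settings from env/.env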

OpenAIResponse

Bases: BaseModel

OpenAI response model, either completion or chat.

OpenAIGPT(config=OpenAIGPTConfig())

Bases: LanguageModel

Class for OpenAI LLMs

Source code in langroid/language_models/openai_gpt.py
def __init__(self, config: OpenAIGPTConfig = OpenAIGPTConfig()):
    """
    Args:
        config: configuration for openai-gpt model
    """
    # copy the config to avoid modifying the original
    config = config.copy()
    super().__init__(config)
    self.config: OpenAIGPTConfig = config

    # Run the first time the model is used
    self.run_on_first_use = cache(self.config.run_on_first_use)

    # global override of chat_model,
    # to allow quick testing with other models
    if settings.chat_model != "":
        self.config.chat_model = settings.chat_model
        self.config.completion_model = settings.chat_model

    if len(parts := self.config.chat_model.split("//")) > 1:
        # there is a formatter specified, e.g.
        # "litellm/ollama/mistral//hf" or
        # "local/localhost:8000/v1//mistral-instruct-v0.2"
        formatter = parts[1]
        self.config.chat_model = parts[0]
        if formatter == "hf":
            # e.g. "litellm/ollama/mistral//hf" -> "litellm/ollama/mistral"
            formatter = find_hf_formatter(self.config.chat_model)
            if formatter != "":
                # e.g. "mistral"
                self.config.formatter = formatter
                logging.warning(
                    f"""
                    Using completions (not chat) endpoint with HuggingFace 
                    chat_template for {formatter} for 
                    model {self.config.chat_model}
                    """
                )
        else:
            # e.g. "local/localhost:8000/v1//mistral-instruct-v0.2"
            self.config.formatter = formatter

    if self.config.formatter is not None:
        self.config.hf_formatter = HFFormatter(
            HFPromptFormatterConfig(model_name=self.config.formatter)
        )

    # if model name starts with "litellm",
    # set the actual model name by stripping the "litellm/" prefix
    # and set the litellm flag to True
    if self.config.chat_model.startswith("litellm/") or self.config.litellm:
        # e.g. litellm/ollama/mistral
        self.config.litellm = True
        self.api_base = self.config.api_base
        if self.config.chat_model.startswith("litellm/"):
            # strip the "litellm/" prefix
            # e.g. litellm/ollama/llama2 => ollama/llama2
            self.config.chat_model = self.config.chat_model.split("/", 1)[1]
    elif self.config.chat_model.startswith("local/"):
        # expect this to be of the form "local/localhost:8000/v1",
        # depending on how the model is launched locally.
        # In this case the model is served locally behind an OpenAI-compatible
        # API, so we can just use `openai.*` methods directly,
        # and don't need an adaptor library like litellm
        self.config.litellm = False
        self.config.seed = None  # some models raise an error when seed is set
        # Extract the api_base from the model name after the "local/" prefix
        self.api_base = self.config.chat_model.split("/", 1)[1]
        if not self.api_base.startswith("http"):
            self.api_base = "http://" + self.api_base
    elif self.config.chat_model.startswith("ollama/"):
        self.config.ollama = True
        self.api_base = OLLAMA_BASE_URL
        self.api_key = OLLAMA_API_KEY
        self.config.chat_model = self.config.chat_model.replace("ollama/", "")
    else:
        self.api_base = self.config.api_base

    if settings.chat_model != "":
        # if we're overriding chat model globally, set completion model to same
        self.config.completion_model = self.config.chat_model

    if self.config.formatter is not None:
        # we want to format chats -> completions using this specific formatter
        self.config.use_completion_for_chat = True
        self.config.completion_model = self.config.chat_model

    if self.config.use_completion_for_chat:
        self.config.use_chat_for_completion = False

    # NOTE: The api_key should be set in the .env file, or via
    # an explicit `export OPENAI_API_KEY=xxx` or `setenv OPENAI_API_KEY xxx`
    # Pydantic's BaseSettings will automatically pick it up from the
    # .env file
    # The config.api_key is ignored when not using an OpenAI model
    if self.is_openai_completion_model() or self.is_openai_chat_model():
        self.api_key = config.api_key
        if self.api_key == DUMMY_API_KEY:
            self.api_key = os.getenv("OPENAI_API_KEY", DUMMY_API_KEY)
    else:
        self.api_key = DUMMY_API_KEY

    self.is_groq = self.config.chat_model.startswith("groq/")

    if self.is_groq:
        self.config.chat_model = self.config.chat_model.replace("groq/", "")
        self.api_key = os.getenv("GROQ_API_KEY", DUMMY_API_KEY)
        self.client = Groq(
            api_key=self.api_key,
        )
        self.async_client = AsyncGroq(
            api_key=self.api_key,
        )
    else:
        self.client = OpenAI(
            api_key=self.api_key,
            base_url=self.api_base,
            organization=self.config.organization,
            timeout=Timeout(self.config.timeout),
        )
        self.async_client = AsyncOpenAI(
            api_key=self.api_key,
            organization=self.config.organization,
            base_url=self.api_base,
            timeout=Timeout(self.config.timeout),
        )

    self.cache: MomentoCache | RedisCache
    if settings.cache_type == "momento":
        if config.cache_config is None or isinstance(
            config.cache_config, RedisCacheConfig
        ):
            # switch to fresh momento config if needed
            config.cache_config = MomentoCacheConfig()
        self.cache = MomentoCache(config.cache_config)
    elif "redis" in settings.cache_type:
        if config.cache_config is None or isinstance(
            config.cache_config, MomentoCacheConfig
        ):
            # switch to fresh redis config if needed
            config.cache_config = RedisCacheConfig(
                fake="fake" in settings.cache_type
            )
        if "fake" in settings.cache_type:
            # force use of fake redis if global cache_type is "fakeredis"
            config.cache_config.fake = True
        self.cache = RedisCache(config.cache_config)
    else:
        raise ValueError(
            f"Invalid cache type {settings.cache_type}. "
            "Valid types are momento, redis, fakeredis"
        )

    self.config._validate_litellm()
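
A minimal usage sketch, assuming the chat() method inherited from LanguageModel (the exact signature and response fields may differ; this is illustrative only):

from langroid.language_models.openai_gpt import OpenAIGPT, OpenAIGPTConfig

llm = OpenAIGPT(OpenAIGPTConfig(chat_model="gpt-4"))

# chat() comes from the LanguageModel base class; signature shown is indicative
response = llm.chat("What is the capital of France?", max_tokens=50)
print(response.message)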

chat_context_length()

Context length for chat-completion models/endpoints. Get it from the dict; otherwise fall back to the general method.

Source code in langroid/language_models/openai_gpt.py
def chat_context_length(self) -> int:
    """
    Context-length for chat-completion models/endpoints
    Get it from the dict, otherwise fail-over to general method
    """
    model = (
        self.config.completion_model
        if self.config.use_completion_for_chat
        else self.config.chat_model
    )
    return _context_length.get(model, super().chat_context_length())

completion_context_length()

Context length for completion models/endpoints. Get it from the dict; otherwise fall back to the general method.

Source code in langroid/language_models/openai_gpt.py
def completion_context_length(self) -> int:
    """
    Context-length for completion models/endpoints
    Get it from the dict, otherwise fail-over to general method
    """
    model = (
        self.config.chat_model
        if self.config.use_chat_for_completion
        else self.config.completion_model
    )
    return _context_length.get(model, super().completion_context_length())

chat_cost()

(Prompt, Generation) cost per 1000 tokens, for chat-completion models/endpoints. Get it from the dict; otherwise fall back to the general method.

Source code in langroid/language_models/openai_gpt.py
def chat_cost(self) -> Tuple[float, float]:
    """
    (Prompt, Generation) cost per 1000 tokens, for chat-completion
    models/endpoints.
    Get it from the dict, otherwise fail-over to general method
    """
    return _cost_per_1k_tokens.get(self.config.chat_model, super().chat_cost())
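
For example, a sketch of turning these per-1000-token rates into the cost of one call (token counts and rates below are made up):

prompt_rate, gen_rate = llm.chat_cost()  # e.g. (0.01, 0.03) dollars per 1K tokens
prompt_tokens, gen_tokens = 1200, 300    # hypothetical usage for one call
cost = (prompt_tokens / 1000) * prompt_rate + (gen_tokens / 1000) * gen_rate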

set_stream(stream)

Enable or disable streaming output from the API.

Args:
    stream: enable streaming output from the API

Returns: previous value of stream

Source code in langroid/language_models/openai_gpt.py
def set_stream(self, stream: bool) -> bool:
    """Enable or disable streaming output from API.
    Args:
        stream: enable streaming output from API
    Returns: previous value of stream
    """
    tmp = self.config.stream
    self.config.stream = stream
    return tmp
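
Since set_stream returns the previous value, it can be used to temporarily change streaming and then restore it, e.g.:

old_stream = llm.set_stream(False)  # disable streaming, remember prior setting
# ... make a non-streaming call ...
llm.set_stream(old_stream)          # restore the previous setting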

get_stream()

Get streaming status

Source code in langroid/language_models/openai_gpt.py
def get_stream(self) -> bool:
    """Get streaming status"""
    return self.config.stream

noop()

Does nothing.

Source code in langroid/language_models/openai_gpt.py
def noop() -> None:
    """Does nothing."""
    return None

litellm_logging_fn(model_call_dict)

Logging function for litellm

Source code in langroid/language_models/openai_gpt.py
def litellm_logging_fn(model_call_dict: Dict[str, Any]) -> None:
    """Logging function for litellm"""
    try:
        api_input_dict = model_call_dict.get("additional_args", {}).get(
            "complete_input_dict"
        )
        if api_input_dict is not None:
            text = escape(json.dumps(api_input_dict, indent=2))
            print(
                f"[grey37]LITELLM: {text}[/grey37]",
            )
    except Exception:
        pass