Skip to content
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,8 @@ sagemaker = [
"openai>=1.68.0,<3.0.0", # SageMaker uses OpenAI-compatible interface
]
otel = ["opentelemetry-exporter-otlp-proto-http>=1.30.0,<2.0.0"]
# TODO(#555): rename this extra once compression/context management features land
token-estimation = ["tiktoken>=0.7.0,<1.0.0"]
Comment thread
opieter-aws marked this conversation as resolved.
Outdated
docs = [
"sphinx>=5.0.0,<10.0.0",
"sphinx-rtd-theme>=1.0.0,<4.0.0",
Expand Down
135 changes: 134 additions & 1 deletion src/strands/models/model.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Abstract base class for Agent model providers."""

import abc
import json
import logging
from collections.abc import AsyncGenerator, AsyncIterable
from dataclasses import dataclass
Expand All @@ -10,7 +11,7 @@

from ..hooks.events import AfterInvocationEvent
from ..plugins.plugin import Plugin
from ..types.content import Messages, SystemContentBlock
from ..types.content import ContentBlock, Messages, SystemContentBlock
from ..types.streaming import StreamEvent
from ..types.tools import ToolChoice, ToolSpec

Expand All @@ -21,6 +22,110 @@

T = TypeVar("T", bound=BaseModel)

# Encoding used for provider-agnostic token estimation.
_DEFAULT_ENCODING = "cl100k_base"
# Module-level cache so the (relatively expensive) encoding lookup runs once.
_cached_encoding: Any = None


def _get_encoding() -> Any:
    """Return the shared tiktoken encoding, loading it lazily on first use.

    Returns:
        The cached ``cl100k_base`` tiktoken encoding.

    Raises:
        ImportError: If the optional ``tiktoken`` dependency is not installed.
    """
    global _cached_encoding
    if _cached_encoding is not None:
        return _cached_encoding

    try:
        import tiktoken
    except ImportError as err:
        raise ImportError(
            "tiktoken is required for token estimation. "
            "Install it with: pip install strands-agents[token-estimation]"
        ) from err

    _cached_encoding = tiktoken.get_encoding(_DEFAULT_ENCODING)
    return _cached_encoding


def _count_content_block_tokens(block: ContentBlock, encoding: Any) -> int:
"""Count tokens for a single content block."""
total = 0

if "text" in block:
total += len(encoding.encode(block["text"]))

if "toolUse" in block:
tool_use = block["toolUse"]
total += len(encoding.encode(tool_use.get("name", "")))
try:
total += len(encoding.encode(json.dumps(tool_use.get("input", {}))))
except (TypeError, ValueError):
logger.debug(
"tool_name=<%s> | skipping non-serializable toolUse input for token estimation",
tool_use.get("name", "unknown"),
)

if "toolResult" in block:
tool_result = block["toolResult"]
for item in tool_result.get("content", []):
if "text" in item:
total += len(encoding.encode(item["text"]))

if "reasoningContent" in block:
reasoning = block["reasoningContent"]
if "reasoningText" in reasoning:
reasoning_text = reasoning["reasoningText"]
if "text" in reasoning_text:
total += len(encoding.encode(reasoning_text["text"]))

if "guardContent" in block:
guard = block["guardContent"]
if "text" in guard:
total += len(encoding.encode(guard["text"]["text"]))
Comment thread
lizradway marked this conversation as resolved.
Outdated

if "citationsContent" in block:
citations = block["citationsContent"]
if "content" in citations:
for citation_item in citations["content"]:
if "text" in citation_item:
total += len(encoding.encode(citation_item["text"]))

return total


def _estimate_tokens_with_tiktoken(
    messages: Messages,
    tool_specs: list[ToolSpec] | None = None,
    system_prompt: str | None = None,
    system_prompt_content: list[SystemContentBlock] | None = None,
) -> int:
    """Estimate tokens by serializing messages/tools to text and counting with tiktoken.

    This is a best-effort fallback for providers that don't expose native counting.
    Accuracy varies by model but is sufficient for threshold-based decisions.
    """
    encoding = _get_encoding()

    # Prefer the structured system_prompt_content over the plain system_prompt to
    # avoid double-counting: providers wrap system_prompt into
    # system_prompt_content when both are supplied.
    if system_prompt_content:
        total = sum(len(encoding.encode(block["text"])) for block in system_prompt_content if "text" in block)
    elif system_prompt:
        total = len(encoding.encode(system_prompt))
    else:
        total = 0

    total += sum(
        _count_content_block_tokens(block, encoding) for message in messages for block in message["content"]
    )

    for spec in tool_specs or []:
        try:
            total += len(encoding.encode(json.dumps(spec)))
        except (TypeError, ValueError):
            logger.debug(
                "tool_name=<%s> | skipping non-serializable tool spec for token estimation",
                spec.get("name", "unknown"),
            )

    return total


@dataclass
class CacheConfig:
Expand Down Expand Up @@ -130,6 +235,34 @@ def stream(
"""
pass

def _estimate_tokens(
    self,
    messages: Messages,
    tool_specs: list[ToolSpec] | None = None,
    system_prompt: str | None = None,
    system_prompt_content: list[SystemContentBlock] | None = None,
) -> int:
    """Estimate token count for the given input before sending to the model.

    Intended for proactive context management, e.g. deciding when to trigger
    compression against a threshold. The default implementation is a naive
    approximation based on tiktoken's ``cl100k_base`` encoding; actual token
    counts vary by provider (presumably within a small margin for most, but
    this is not guaranteed), so do not use it for billing or precise quota
    calculations.

    Subclasses may override this method to provide model-specific token
    counting via native APIs for improved accuracy.

    Args:
        messages: List of message objects to estimate tokens for.
        tool_specs: List of tool specifications to include in the estimate.
        system_prompt: Plain string system prompt. Ignored if system_prompt_content is provided.
        system_prompt_content: Structured system prompt content blocks. Takes priority over system_prompt.

    Returns:
        Estimated total input tokens.
    """
    return _estimate_tokens_with_tiktoken(
        messages,
        tool_specs=tool_specs,
        system_prompt=system_prompt,
        system_prompt_content=system_prompt_content,
    )


class _ModelPlugin(Plugin):
"""Plugin that manages model-related lifecycle hooks."""
Expand Down
Loading
Loading