Files
clawgo/pkg/providers/openai_compat_provider_test.go

328 lines
12 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package providers
import (
"encoding/json"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
)
// TestBuildQwenChatRequestStripsSuffixAndAppliesThinking verifies that a
// "(high)" suffix on the requested model name is removed from the outgoing
// "model" field and surfaced as "reasoning_effort" instead.
func TestBuildQwenChatRequestStripsSuffixAndAppliesThinking(t *testing.T) {
	provider := NewHTTPProvider("qwen", "token", qwenCompatBaseURL, "qwen-max", false, "oauth", 5*time.Second, nil)
	history := []Message{{Role: "user", Content: "hi"}}
	req := buildQwenChatRequest(provider, history, nil, "qwen-max(high)", nil, false)

	model := req["model"]
	if model != "qwen-max" {
		t.Fatalf("model = %#v, want qwen-max", model)
	}
	effort := req["reasoning_effort"]
	if effort != "high" {
		t.Fatalf("reasoning_effort = %#v, want high", effort)
	}
}
// TestBuildQwenChatRequestAddsPoisonToolForStreamingWithoutTools verifies that
// a streaming request built without caller-supplied tools carries exactly one
// "poison" tool whose function name is do_not_call_me.
func TestBuildQwenChatRequestAddsPoisonToolForStreamingWithoutTools(t *testing.T) {
	provider := NewHTTPProvider("qwen", "token", qwenCompatBaseURL, "qwen-max", false, "oauth", 5*time.Second, nil)
	history := []Message{{Role: "user", Content: "hi"}}
	req := buildQwenChatRequest(provider, history, nil, "qwen-max", nil, true)

	rawTools := req["tools"]
	toolList, isList := rawTools.([]map[string]any)
	if !isList || len(toolList) != 1 {
		t.Fatalf("tools = %#v, want single poison tool", rawTools)
	}

	fn, _ := toolList[0]["function"].(map[string]any)
	name := fn["name"]
	if name != "do_not_call_me" {
		t.Fatalf("tool name = %#v, want do_not_call_me", name)
	}
}
// TestClassifyQwenFailureMapsQuotaTo429UntilNextMidnight verifies that a 403
// insufficient_quota reply is reclassified as 429 with the quota failure
// reason, marked retryable, and given a positive retry-after of at most one day.
func TestClassifyQwenFailureMapsQuotaTo429UntilNextMidnight(t *testing.T) {
	body := []byte(`{"error":{"code":"insufficient_quota","message":"free allocated quota exceeded"}}`)
	status, reason, retry, retryAfter := classifyQwenFailure(http.StatusForbidden, body)
	if status != http.StatusTooManyRequests {
		t.Fatalf("status = %d, want %d", status, http.StatusTooManyRequests)
	}
	if reason != oauthFailureQuota || !retry {
		t.Fatalf("reason=%q retry=%v, want reason=%q retry=true", reason, retry, oauthFailureQuota)
	}
	// retryAfter is a pointer: report nil separately and print the pointed-to
	// duration, not the pointer address, so a failure is actionable.
	if retryAfter == nil {
		t.Fatal("retryAfter = nil, want within next day")
	}
	if d := *retryAfter; d <= 0 || d > 24*time.Hour {
		t.Fatalf("retryAfter = %v, want within next day", d)
	}
}
// TestQwenProviderChatMapsQuota403To429 checks end to end that when the
// upstream server answers 403 insufficient_quota, Chat returns an error whose
// text mentions "status 429".
func TestQwenProviderChatMapsQuota403To429(t *testing.T) {
	const quotaBody = `{"error":{"code":"insufficient_quota","message":"free allocated quota exceeded"}}`
	upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		w.Header().Set("Content-Type", "application/json")
		w.WriteHeader(http.StatusForbidden)
		_, _ = w.Write([]byte(quotaBody))
	}))
	t.Cleanup(upstream.Close)

	provider := NewQwenProvider("qwen-quota", "token", upstream.URL, "qwen-max", false, "api_key", 5*time.Second, nil)
	_, err := provider.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, nil, "qwen-max", nil)
	switch {
	case err == nil:
		t.Fatal("expected error")
	case !strings.Contains(err.Error(), "status 429"):
		t.Fatalf("error = %v, want mapped 429", err)
	}
}
// TestQwenProviderCountTokens checks that CountTokens reports a positive
// prompt token count and a total equal to that prompt count.
func TestQwenProviderCountTokens(t *testing.T) {
	p := NewQwenProvider("qwen", "token", qwenCompatBaseURL, "qwen-max", false, "api_key", 5*time.Second, nil)
	msgs := []Message{{Role: "user", Content: "hello qwen"}}

	usage, err := p.CountTokens(t.Context(), msgs, nil, "qwen-max", nil)
	if err != nil {
		t.Fatalf("CountTokens error: %v", err)
	}

	promptOnly := usage != nil && usage.PromptTokens > 0 && usage.TotalTokens == usage.PromptTokens
	if !promptOnly {
		t.Fatalf("usage = %#v, want positive prompt-only count", usage)
	}
}
// TestApplyAttemptProviderHeadersQwenUsesDynamicStainlessValues checks that a
// Qwen OAuth streaming attempt receives X-Stainless-Arch and X-Stainless-Os
// headers equal to qwenStainlessArch() and qwenStainlessOS(), plus an
// "Accept: text/event-stream" header.
func TestApplyAttemptProviderHeadersQwenUsesDynamicStainlessValues(t *testing.T) {
	req, err := http.NewRequest(http.MethodPost, qwenCompatBaseURL+"/chat/completions", nil)
	if err != nil {
		t.Fatalf("new request: %v", err)
	}

	qwen := &HTTPProvider{oauth: &oauthManager{cfg: oauthConfig{Provider: defaultQwenOAuthProvider}}}
	attempt := authAttempt{kind: "oauth", token: "qwen-token"}
	applyAttemptProviderHeaders(req, attempt, qwen, true)

	wantArch := qwenStainlessArch()
	if arch := req.Header.Get("X-Stainless-Arch"); arch != wantArch {
		t.Fatalf("X-Stainless-Arch = %q, want %q", arch, wantArch)
	}
	wantOS := qwenStainlessOS()
	if osName := req.Header.Get("X-Stainless-Os"); osName != wantOS {
		t.Fatalf("X-Stainless-Os = %q, want %q", osName, wantOS)
	}
	if accept := req.Header.Get("Accept"); accept != "text/event-stream" {
		t.Fatalf("Accept = %q, want text/event-stream", accept)
	}
}
// TestNormalizeQwenResourceURL checks normalizeQwenResourceURL on three inputs:
//   - an https "/api" URL, which is rewritten to "/v1";
//   - a scheme-less "/api" URL, which also gains an https scheme;
//   - an already "/v1" URL, which is left unchanged.
//
// Each case runs as its own subtest and reports with t.Errorf, so one failing
// input does not hide the results of the others.
func TestNormalizeQwenResourceURL(t *testing.T) {
	tests := []struct {
		name string
		in   string
		want string
	}{
		{name: "api_path_with_scheme", in: "https://chat.qwen.ai/api", want: "https://chat.qwen.ai/v1"},
		{name: "api_path_without_scheme", in: "chat.qwen.ai/api", want: "https://chat.qwen.ai/v1"},
		{name: "already_v1", in: "https://portal.qwen.ai/v1", want: "https://portal.qwen.ai/v1"},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			if got := normalizeQwenResourceURL(tt.in); got != tt.want {
				t.Errorf("normalizeQwenResourceURL(%q) = %q, want %q", tt.in, got, tt.want)
			}
		})
	}
}
// TestQwenHookUsesSessionResourceURL verifies that the Qwen endpoint hook
// builds the request URL from the OAuth session's ResourceURL, normalizing its
// "/api" path to "/v1", rather than from the provider's configured base URL.
func TestQwenHookUsesSessionResourceURL(t *testing.T) {
	const want = "https://chat.qwen.ai/v1/chat/completions"
	hooks := qwenProviderHooks{}
	base := NewHTTPProvider("qwen", "token", qwenCompatBaseURL, "qwen-max", false, "oauth", 5*time.Second, nil)
	attempt := authAttempt{
		kind: "oauth",
		session: &oauthSession{
			ResourceURL: "https://chat.qwen.ai/api",
		},
	}
	// Include the expected value in the message so a failure is self-explanatory.
	if got := hooks.endpoint(base, attempt, "/chat/completions"); got != want {
		t.Fatalf("endpoint = %q, want %q", got, want)
	}
}
func TestOpenAICompatMessagesPreserveMultimodalContentParts(t *testing.T) {
msgs := openAICompatMessages([]Message{{
Role: "user",
ContentParts: []MessageContentPart{
{Type: "text", Text: "look"},
{Type: "input_image", ImageURL: "https://example.com/cat.png", Detail: "high"},
},
}})
if len(msgs) != 1 {
t.Fatalf("messages len = %d", len(msgs))
}
content, ok := msgs[0]["content"].([]map[string]interface{})
if !ok || len(content) != 2 {
t.Fatalf("content = %#v", msgs[0]["content"])
}
if got := content[0]["type"]; got != "text" {
t.Fatalf("first part type = %#v", got)
}
imagePart, _ := content[1]["image_url"].(map[string]interface{})
if got := content[1]["type"]; got != "image_url" {
t.Fatalf("second part type = %#v", got)
}
if got := imagePart["url"]; got != "https://example.com/cat.png" {
t.Fatalf("image url = %#v", got)
}
if got := imagePart["detail"]; got != "high" {
t.Fatalf("image detail = %#v", got)
}
}
// TestBuildOpenAICompatChatRequestAppliesThinkingSuffix checks that a generic
// OpenAI-compatible request for "gpt-5(high)" is sent with model "gpt-5" and
// reasoning_effort "high".
func TestBuildOpenAICompatChatRequestAppliesThinkingSuffix(t *testing.T) {
	timeout := 5 * time.Second
	p := NewHTTPProvider("openai", "token", "https://example.com/v1", "gpt-5", false, "api_key", timeout, nil)
	history := []Message{{Role: "user", Content: "hi"}}
	req := p.buildOpenAICompatChatRequest(history, nil, "gpt-5(high)", nil)

	if model := req["model"]; model != "gpt-5" {
		t.Fatalf("model = %#v, want gpt-5", model)
	}
	if effort := req["reasoning_effort"]; effort != "high" {
		t.Fatalf("reasoning_effort = %#v, want high", effort)
	}
}
// TestBuildOpenAICompatChatRequestStripsKimiPrefixAndSuffix checks a provider
// whose OAuth config names the Kimi OAuth provider. For the requested model
// "kimi-k2.5(-1)" it expects model "k2.5" and reasoning_effort "auto".
func TestBuildOpenAICompatChatRequestStripsKimiPrefixAndSuffix(t *testing.T) {
	kimi := NewHTTPProvider("kimi", "token", kimiCompatBaseURL, "kimi-k2.5", false, "oauth", 5*time.Second, nil)
	kimi.oauth = &oauthManager{cfg: oauthConfig{Provider: defaultKimiOAuthProvider}}
	req := kimi.buildOpenAICompatChatRequest([]Message{{Role: "user", Content: "hi"}}, nil, "kimi-k2.5(-1)", nil)

	switch model, effort := req["model"], req["reasoning_effort"]; {
	case model != "k2.5":
		t.Fatalf("model = %#v, want k2.5", model)
	case effort != "auto":
		t.Fatalf("reasoning_effort = %#v, want auto", effort)
	}
}
// TestHTTPProviderChatUsesConfiguredChatCompletionsAPI verifies a provider with
// responsesAPI set to "chat_completions". It checks that Chat posts to
// <base>/chat/completions with the requested model, and that the reply's
// message content and usage totals appear on the returned response.
func TestHTTPProviderChatUsesConfiguredChatCompletionsAPI(t *testing.T) {
	var gotPath string
	var gotBody map[string]interface{}
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		gotPath = r.URL.Path
		if err := json.NewDecoder(r.Body).Decode(&gotBody); err != nil {
			// Handlers run on server goroutines, where t.Fatalf (FailNow) must
			// not be called. Record the failure and reject the request instead.
			t.Errorf("decode request: %v", err)
			http.Error(w, "invalid request body", http.StatusBadRequest)
			return
		}
		w.Header().Set("Content-Type", "application/json")
		// A write error here would surface as a Chat error in the test goroutine.
		_, _ = w.Write([]byte(`{"choices":[{"message":{"content":"hello from chat"},"finish_reason":"stop"}],"usage":{"prompt_tokens":1,"completion_tokens":2,"total_tokens":3}}`))
	}))
	defer server.Close()

	provider := NewHTTPProvider("openai", "token", server.URL+"/v1", "gpt-5", false, "api_key", 5*time.Second, nil)
	provider.responsesAPI = "chat_completions"
	resp, err := provider.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, nil, "gpt-5", nil)
	if err != nil {
		t.Fatalf("Chat error: %v", err)
	}
	if gotPath != "/v1/chat/completions" {
		t.Fatalf("path = %q, want /v1/chat/completions", gotPath)
	}
	if gotBody["model"] != "gpt-5" {
		t.Fatalf("model = %#v, want gpt-5", gotBody["model"])
	}
	if resp.Content != "hello from chat" {
		t.Fatalf("content = %q, want hello from chat", resp.Content)
	}
	if resp.Usage == nil || resp.Usage.TotalTokens != 3 {
		t.Fatalf("usage = %#v, want total_tokens=3", resp.Usage)
	}
}
// TestParseOpenAICompatResponseCapturesReasoningContent checks that the
// message-level "reasoning_content" field of a chat completion reply is copied
// onto the parsed response's ReasoningContent.
func TestParseOpenAICompatResponseCapturesReasoningContent(t *testing.T) {
	const payload = `{"choices":[{"message":{"content":"answer","reasoning_content":"hidden chain"},"finish_reason":"stop"}]}`
	parsed, err := parseOpenAICompatResponse([]byte(payload))
	if err != nil {
		t.Fatalf("parseOpenAICompatResponse error: %v", err)
	}
	if got := parsed.ReasoningContent; got != "hidden chain" {
		t.Fatalf("ReasoningContent = %q, want hidden chain", got)
	}
}
func TestOpenAICompatMessagesIncludeReasoningContent(t *testing.T) {
msgs := openAICompatMessages([]Message{{
Role: "assistant",
Content: "tool plan",
ReasoningContent: "thinking trace",
ToolCalls: []ToolCall{{
ID: "call_1",
Name: "read_file",
Function: &FunctionCall{
Name: "read_file",
Arguments: `{"path":"a.txt"}`,
},
}},
}})
if len(msgs) != 1 {
t.Fatalf("messages len = %d", len(msgs))
}
if got := msgs[0]["reasoning_content"]; got != "thinking trace" {
t.Fatalf("reasoning_content = %#v, want thinking trace", got)
}
}
// TestNormalizeOpenAICompatThinkingMessagesBackfillsReasoningForToolCalls
// checks an assistant message that has tool_calls but no reasoning_content.
// After normalizeOpenAICompatThinkingMessages runs on the request body, that
// message's reasoning_content should be filled from its content.
func TestNormalizeOpenAICompatThinkingMessagesBackfillsReasoningForToolCalls(t *testing.T) {
	body := map[string]interface{}{
		"messages": []map[string]interface{}{
			{
				"role": "assistant",
				"tool_calls": []map[string]interface{}{
					{"id": "call_1"},
				},
				"content": "thinking content",
			},
		},
	}
	normalizeOpenAICompatThinkingMessages(body)
	// Use a checked assertion and a length guard so that a change in the
	// normalized shape fails the test with a message instead of panicking.
	msgs, ok := body["messages"].([]map[string]interface{})
	if !ok || len(msgs) == 0 {
		t.Fatalf("messages = %#v, want non-empty []map[string]interface{}", body["messages"])
	}
	if got := msgs[0]["reasoning_content"]; got != "thinking content" {
		t.Fatalf("reasoning_content = %#v, want thinking content", got)
	}
}
// TestHTTPProviderChatConfiguredCompatBackfillsReasoningContentForToolHistory
// sends a history containing an assistant tool-call turn with no
// ReasoningContent, using a thinking-suffixed model. It verifies that the
// request reaching the server carries reasoning_content on that assistant
// message, filled from its content.
func TestHTTPProviderChatConfiguredCompatBackfillsReasoningContentForToolHistory(t *testing.T) {
	var gotBody map[string]interface{}
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if err := json.NewDecoder(r.Body).Decode(&gotBody); err != nil {
			// Handlers run on server goroutines, where t.Fatalf (FailNow) must
			// not be called. Record the failure and reject the request instead.
			t.Errorf("decode request: %v", err)
			http.Error(w, "invalid request body", http.StatusBadRequest)
			return
		}
		w.Header().Set("Content-Type", "application/json")
		// A write error here would surface as a Chat error in the test goroutine.
		_, _ = w.Write([]byte(`{"choices":[{"message":{"content":"ok"},"finish_reason":"stop"}]}`))
	}))
	defer server.Close()

	provider := NewHTTPProvider("openai", "token", server.URL+"/v1", "gpt-5", false, "api_key", 5*time.Second, nil)
	provider.responsesAPI = "chat_completions"
	_, err := provider.Chat(t.Context(), []Message{
		{Role: "user", Content: "hello"},
		{
			Role:    "assistant",
			Content: "thinking content",
			ToolCalls: []ToolCall{{
				ID:   "call_1",
				Name: "read_file",
				Function: &FunctionCall{
					Name:      "read_file",
					Arguments: `{"path":"a.txt"}`,
				},
			}},
		},
		{Role: "tool", ToolCallID: "call_1", Content: "file body"},
	}, nil, "gpt-5(high)", nil)
	if err != nil {
		t.Fatalf("Chat error: %v", err)
	}
	rawMsgs, _ := gotBody["messages"].([]interface{})
	if len(rawMsgs) < 2 {
		t.Fatalf("messages = %#v", gotBody["messages"])
	}
	assistant, _ := rawMsgs[1].(map[string]interface{})
	if got := assistant["reasoning_content"]; got != "thinking content" {
		t.Fatalf("reasoning_content = %#v, want thinking content", got)
	}
}
// TestParseCompatFunctionCallsSupportsDSMLToolCalls verifies that
// parseCompatFunctionCalls extracts exactly one tool call named read_file from
// DSML-style tool-call markup. It also checks that the returned cleaned text is
// empty once surrounding whitespace is trimmed.
//
// NOTE(review): this file is flagged as containing ambiguous Unicode
// characters. The DSML markers in the literal below may include non-ASCII
// delimiter characters that are not visible in every viewer. Verify the exact
// bytes before editing the literal.
func TestParseCompatFunctionCallsSupportsDSMLToolCalls(t *testing.T) {
	calls, cleaned := parseCompatFunctionCalls(`<DSMLtool_calls><DSMLinvoke name="read_file"></DSMLinvoke></DSMLtool_calls>`)
	if len(calls) != 1 {
		t.Fatalf("calls = %#v, want one tool call", calls)
	}
	if calls[0].Name != "read_file" {
		t.Fatalf("tool name = %q, want read_file", calls[0].Name)
	}
	// The markup should be fully consumed; only whitespace may remain.
	if strings.TrimSpace(cleaned) != "" {
		t.Fatalf("cleaned = %q, want empty", cleaned)
	}
}