diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 5b97d69..dd8fbf0 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -67,65 +67,6 @@ jobs: echo "Targets: $build_targets" echo "Variants: $channel_variants" - build-webui: - needs: prepare-release - runs-on: ubuntu-latest - steps: - - name: Checkout - uses: actions/checkout@v4 - - - name: Ensure Node.js 20 - shell: bash - run: | - set -euo pipefail - if command -v node >/dev/null 2>&1; then - node_major="$(node -p 'process.versions.node.split(".")[0]')" - else - node_major="" - fi - if [ "$node_major" != "20" ]; then - arch="$(dpkg --print-architecture)" - case "$arch" in - amd64) node_arch="x64" ;; - arm64) node_arch="arm64" ;; - *) - echo "Unsupported runner architecture for Node install: $arch" >&2 - exit 1 - ;; - esac - curl -fsSL "https://nodejs.org/dist/v20.19.5/node-v20.19.5-linux-${node_arch}.tar.xz" -o /tmp/node.tar.xz - sudo rm -rf /usr/local/lib/nodejs - sudo mkdir -p /usr/local/lib/nodejs - sudo tar -xJf /tmp/node.tar.xz -C /usr/local/lib/nodejs - echo "/usr/local/lib/nodejs/node-v20.19.5-linux-${node_arch}/bin" >> "$GITHUB_PATH" - fi - node --version - npm --version - - - name: Install WebUI dependencies - shell: bash - run: | - set -euo pipefail - cd webui - if [ -f package-lock.json ]; then - npm ci - else - npm install - fi - - - name: Build WebUI - shell: bash - run: | - set -euo pipefail - make build-webui - - - name: Upload WebUI dist - uses: actions/upload-artifact@v4 - with: - name: webui-dist - path: webui/dist - if-no-files-found: error - prepare-go-cache: needs: prepare-release runs-on: ubuntu-latest @@ -172,7 +113,6 @@ jobs: build-and-package: needs: - prepare-release - - build-webui - prepare-go-cache runs-on: ubuntu-latest strategy: @@ -203,12 +143,6 @@ jobs: echo "/usr/local/go/bin" >> "$GITHUB_PATH" /usr/local/go/bin/go version - - name: Download WebUI dist - uses: actions/download-artifact@v4 - with: - name: webui-dist - path: webui/dist - - name: Download Go module cache uses: actions/download-artifact@v4 with: @@ -248,8 +182,7 @@ jobs: make package-all \ VERSION="${{ needs.prepare-release.outputs.version }}" \ BUILD_TARGETS="${{ matrix.target }}" \ - CHANNEL_PACKAGE_VARIANTS="${{ needs.prepare-release.outputs.channel_variants }}" \ - SKIP_WEBUI_BUILD=1 + CHANNEL_PACKAGE_VARIANTS="${{ needs.prepare-release.outputs.channel_variants }}" rm -f build/checksums.txt - name: Upload matrix artifacts diff --git a/Makefile b/Makefile index 660113f..ca3951c 100644 --- a/Makefile +++ b/Makefile @@ -1,4 +1,4 @@ -.PHONY: all build build-variants build-linux-slim build-all build-all-variants build-webui package-all install install-win uninstall clean help test test-docker install-bootstrap-docs sync-embed-workspace sync-embed-workspace-base sync-embed-webui cleanup-embed-workspace test-only clean-test-artifacts dev +.PHONY: all build build-variants build-linux-slim build-all build-all-variants package-all install install-win uninstall clean help test test-docker install-bootstrap-docs sync-embed-workspace cleanup-embed-workspace test-only clean-test-artifacts dev # Build variables BINARY_NAME=clawgo @@ -55,15 +55,9 @@ WORKSPACE_SKILLS_DIR=$(WORKSPACE_DIR)/skills BUILTIN_SKILLS_DIR=$(CURDIR)/skills WORKSPACE_SOURCE_DIR=$(CURDIR)/workspace EMBED_WORKSPACE_DIR=$(CURDIR)/cmd/workspace -EMBED_WEBUI_DIR=$(EMBED_WORKSPACE_DIR)/webui DEV_CONFIG?=$(if $(wildcard $(CURDIR)/config.json),$(CURDIR)/config.json,$(CLAWGO_HOME)/config.json) DEV_ARGS?=--debug gateway run DEV_WORKSPACE?=$(WORKSPACE_DIR) -DEV_WEBUI_DIR?=$(CURDIR)/webui -WEBUI_DIST_DIR=$(DEV_WEBUI_DIR)/dist -WEBUI_PACKAGE_LOCK=$(DEV_WEBUI_DIR)/package-lock.json -NPM?=npm -SKIP_WEBUI_BUILD?=0 # OS detection UNAME_S:=$(shell uname -s) @@ -224,36 +218,6 @@ build-all-variants: sync-embed-workspace done @echo "All variant builds complete" -## build-webui: Install WebUI dependencies when needed and build dist assets -build-webui: - @echo "Building WebUI..." - @if [ ! -d "$(DEV_WEBUI_DIR)" ]; then \ - echo "✗ Missing WebUI directory: $(DEV_WEBUI_DIR)"; \ - exit 1; \ - fi - @if [ "$(SKIP_WEBUI_BUILD)" = "1" ]; then \ - if [ -d "$(WEBUI_DIST_DIR)" ]; then \ - echo "✓ Reusing existing WebUI dist from $(WEBUI_DIST_DIR)"; \ - exit 0; \ - fi; \ - echo "✗ SKIP_WEBUI_BUILD=1 but WebUI dist is missing: $(WEBUI_DIST_DIR)"; \ - exit 1; \ - fi - @if ! command -v "$(NPM)" >/dev/null 2>&1; then \ - echo "✗ npm is required to build the WebUI"; \ - exit 1; \ - fi - @set -e; \ - if [ ! -d "$(DEV_WEBUI_DIR)/node_modules" ]; then \ - echo "Installing WebUI dependencies..."; \ - if [ -f "$(WEBUI_PACKAGE_LOCK)" ]; then \ - (cd "$(DEV_WEBUI_DIR)" && "$(NPM)" ci); \ - else \ - (cd "$(DEV_WEBUI_DIR)" && "$(NPM)" install); \ - fi; \ - fi; \ - (cd "$(DEV_WEBUI_DIR)" && "$(NPM)" run build) - ## package-all: Create compressed archives and checksums for full, no-channel, and per-channel build variants package-all: build-all-variants @echo "Packaging build artifacts..." @@ -290,31 +254,19 @@ package-all: build-all-variants fi @echo "Package complete: $(BUILD_DIR)" -## sync-embed-workspace: Sync workspace seed files and built WebUI into cmd/workspace for go:embed -sync-embed-workspace: sync-embed-workspace-base sync-embed-webui +## sync-embed-workspace: Sync workspace seed files into cmd/workspace for go:embed +sync-embed-workspace: @echo "✓ Embed assets ready in $(EMBED_WORKSPACE_DIR)" -## sync-embed-workspace-base: Sync root workspace files into cmd/workspace for go:embed -sync-embed-workspace-base: - @echo "Syncing workspace seed files for embedding..." @if [ ! -d "$(WORKSPACE_SOURCE_DIR)" ]; then \ echo "✗ Missing source workspace directory: $(WORKSPACE_SOURCE_DIR)"; \ exit 1; \ fi + @echo "Syncing workspace seed files for embedding..." @mkdir -p "$(EMBED_WORKSPACE_DIR)" @rsync -a --delete "$(WORKSPACE_SOURCE_DIR)/" "$(EMBED_WORKSPACE_DIR)/" @echo "✓ Synced workspace to $(EMBED_WORKSPACE_DIR)" -## sync-embed-webui: Build and sync WebUI dist into embedded workspace assets -sync-embed-webui: build-webui - @if [ ! -d "$(WEBUI_DIST_DIR)" ]; then \ - echo "✗ Missing WebUI dist directory: $(WEBUI_DIST_DIR)"; \ - exit 1; \ - fi - @mkdir -p "$(EMBED_WEBUI_DIR)" - @rsync -a --delete "$(WEBUI_DIST_DIR)/" "$(EMBED_WEBUI_DIR)/" - @echo "✓ Synced WebUI dist to $(EMBED_WEBUI_DIR)" - ## cleanup-embed-workspace: Remove synced embed workspace artifacts cleanup-embed-workspace: @rm -rf "$(EMBED_WORKSPACE_DIR)" @@ -447,24 +399,17 @@ deps: run: build @$(BUILD_DIR)/$(BINARY_NAME) $(ARGS) -## dev: Build WebUI, sync workspace, and run the local gateway in foreground for debugging -dev: build-webui sync-embed-workspace +## dev: Sync workspace and run the local gateway in foreground for debugging +dev: sync-embed-workspace @if [ ! -f "$(DEV_CONFIG)" ]; then \ echo "✗ Missing config file: $(DEV_CONFIG)"; \ echo " Override with: make dev DEV_CONFIG=/path/to/config.json"; \ exit 1; \ fi - @if [ ! -d "$(DEV_WEBUI_DIR)" ]; then \ - echo "✗ Missing WebUI directory: $(DEV_WEBUI_DIR)"; \ - exit 1; \ - fi @set -e; trap '$(MAKE) -C $(CURDIR) cleanup-embed-workspace' EXIT; \ - echo "Syncing WebUI dist to $(DEV_WORKSPACE)/webui ..."; \ - mkdir -p "$(DEV_WORKSPACE)/webui"; \ - rsync -a --delete "$(DEV_WEBUI_DIR)/dist/" "$(DEV_WORKSPACE)/webui/"; \ echo "Starting local gateway debug session..."; \ echo " Config: $(DEV_CONFIG)"; \ - echo " WebUI: $(DEV_WORKSPACE)/webui"; \ + echo " Workspace: $(DEV_WORKSPACE)"; \ echo " Args: $(DEV_ARGS)"; \ CLAWGO_CONFIG="$(DEV_CONFIG)" $(GO) run $(GOFLAGS) ./$(CMD_DIR) $(DEV_ARGS) @@ -498,9 +443,7 @@ help: @echo " WORKSPACE_DIR # Workspace directory (default: ~/.clawgo/workspace)" @echo " DEV_CONFIG # Config path for make dev" @echo " DEV_ARGS # CLI args for make dev (default: --debug gateway run)" - @echo " DEV_WORKSPACE # Workspace path for WebUI sync in make dev" - @echo " DEV_WEBUI_DIR # WebUI source dir for make dev (default: ./webui)" - @echo " NPM # npm executable for WebUI build (default: npm)" + @echo " DEV_WORKSPACE # Workspace path used by make dev" @echo " VERSION # Version string (default: git describe)" @echo " STRIP_SYMBOLS # 1=strip debug/symbol info (default: 1)" @echo "" diff --git a/cmd/cmd_gateway.go b/cmd/cmd_gateway.go index c9be473..3588681 100644 --- a/cmd/cmd_gateway.go +++ b/cmd/cmd_gateway.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "io" + "net/http" "net/url" "os" "os/exec" @@ -27,6 +28,7 @@ import ( "github.com/YspCoder/clawgo/pkg/providers" "github.com/YspCoder/clawgo/pkg/runtimecfg" "github.com/YspCoder/clawgo/pkg/sentinel" + "github.com/YspCoder/clawgo/pkg/wsrelay" "github.com/pion/webrtc/v4" ) @@ -181,6 +183,24 @@ func gatewayCmd() { registryServer.SetWorkspacePath(cfg.WorkspacePath()) registryServer.SetLogFilePath(cfg.LogFilePath()) registryServer.SetWebUIDir(filepath.Join(cfg.WorkspacePath(), "webui")) + aistudioRelay := wsrelay.NewManager(wsrelay.Options{ + Path: "/v1/ws", + ProviderFactory: func(r *http.Request) (string, error) { + provider := strings.TrimSpace(r.URL.Query().Get("provider")) + if provider == "" { + provider = strings.TrimSpace(r.Header.Get("X-Clawgo-Provider")) + } + if provider == "" { + provider = "aistudio" + } + return strings.ToLower(provider), nil + }, + OnConnected: providers.NotifyAIStudioRelayConnected, + OnDisconnected: providers.NotifyAIStudioRelayDisconnected, + }) + defer func() { _ = aistudioRelay.Stop(context.Background()) }() + providers.SetAIStudioRelayManager(aistudioRelay) + registryServer.SetProtectedRoute(aistudioRelay.Path(), aistudioRelay.Handler()) bindAgentLoopHandlers := func(loop *agent.AgentLoop) { registryServer.SetChatHandler(func(cctx context.Context, sessionKey, content string) (string, error) { if strings.TrimSpace(content) == "" { diff --git a/install.sh b/install.sh index ea596bd..643ec9b 100755 --- a/install.sh +++ b/install.sh @@ -3,7 +3,9 @@ set -euo pipefail OWNER="YspCoder" REPO="clawgo" +WEBUI_REPO="clawgo-web" BIN="clawgo" +WEBUI_ASSET="webui.tar.gz" INSTALL_DIR="/usr/local/bin" VARIANT="${CLAWGO_CHANNEL_VARIANT:-full}" VARIANT_EXPLICIT=0 @@ -20,7 +22,7 @@ Usage: $0 [--variant full|none|telegram|discord|feishu|maixcam|qq|dingtalk|whats Install or upgrade ClawGo from the latest GitHub release. Notes: - - WebUI is embedded in the binary and initialized when you run 'clawgo onboard'. + - The installer can optionally download WebUI from the matching tag in $OWNER/$WEBUI_REPO. - Variant 'none' installs the no-channel build. - OpenClaw migration is offered only when a legacy workspace is detected. EOF @@ -248,6 +250,36 @@ install_binary() { log "Installed $BIN to $INSTALL_DIR/$BIN" } +install_webui() { + local target_dir="$1" + local file="$WEBUI_ASSET" + local url="https://github.com/$OWNER/$WEBUI_REPO/releases/download/$TAG/$file" + local out="$TMPDIR/$file" + local extract_dir="$TMPDIR/webui-extract" + + require_cmd tar + require_cmd rsync + mkdir -p "$extract_dir" "$target_dir" + + log "Downloading optional WebUI package from $OWNER/$WEBUI_REPO@$TAG ..." + if ! curl -fSL "$url" -o "$out"; then + warn "Failed to download $file from $url" + return 1 + fi + + rm -rf "$target_dir" + mkdir -p "$target_dir" + tar -xzf "$out" -C "$extract_dir" + + if [[ -d "$extract_dir/webui" ]]; then + rsync -a --delete "$extract_dir/webui/" "$target_dir/" + else + rsync -a --delete "$extract_dir/" "$target_dir/" + fi + + log "Installed WebUI to $target_dir" +} + migrate_local_openclaw() { local src="${1:-$LEGACY_WORKSPACE_DIR}" local dst="${2:-$WORKSPACE_DIR}" @@ -377,17 +409,12 @@ offer_openclaw_migration() { } offer_onboard() { - log "Refreshing embedded WebUI assets..." - "$INSTALL_DIR/$BIN" onboard --sync-webui >/dev/null 2>&1 || warn "Failed to refresh embedded WebUI assets automatically. You can run 'clawgo onboard --sync-webui' later." - if [[ -f "$CONFIG_PATH" ]]; then log "Existing config detected at $CONFIG_PATH" - log "WebUI assets were refreshed from the embedded bundle." log "Run 'clawgo onboard' only if you want to regenerate config or missing workspace templates." return fi - log "WebUI assets were refreshed from the embedded bundle." if prompt_yes_no "No config found. Run 'clawgo onboard' now?" "N"; then "$INSTALL_DIR/$BIN" onboard else @@ -406,6 +433,13 @@ main() { trap 'rm -rf "$TMPDIR"' EXIT install_binary + if prompt_yes_no "Install optional WebUI from $OWNER/$WEBUI_REPO@$TAG?" "N"; then + local_webui_dir="$HOME/clawgo-webui" + tty_read local_webui_dir "Enter WebUI install directory: " "$local_webui_dir" + install_webui "$local_webui_dir" || warn "WebUI installation failed. You can install it later by downloading $WEBUI_ASSET from the $TAG release." + else + log "Skipped optional WebUI installation." + fi offer_openclaw_migration offer_onboard diff --git a/pkg/api/server.go b/pkg/api/server.go index ca7e219..5825ca7 100644 --- a/pkg/api/server.go +++ b/pkg/api/server.go @@ -76,6 +76,8 @@ type Server struct { whatsAppBase string oauthFlowMu sync.Mutex oauthFlows map[string]*providers.OAuthPendingFlow + extraRoutesMu sync.RWMutex + extraRoutes map[string]http.Handler } var nodesWebsocketUpgrader = websocket.Upgrader{ @@ -100,6 +102,7 @@ func NewServer(host string, port int, token string, mgr *nodes.Manager) *Server liveRuntimeSubs: map[chan []byte]struct{}{}, liveSubagents: map[string]*liveSubagentGroup{}, oauthFlows: map[string]*providers.OAuthPendingFlow{}, + extraRoutes: map[string]http.Handler{}, } } @@ -312,6 +315,19 @@ func (s *Server) SetToolsCatalogHandler(fn func() interface{}) { s.onToolsCatalo func (s *Server) SetWebUIDir(dir string) { s.webUIDir = strings.TrimSpace(dir) } func (s *Server) SetGatewayVersion(v string) { s.gatewayVersion = strings.TrimSpace(v) } func (s *Server) SetWebUIVersion(v string) { s.webuiVersion = strings.TrimSpace(v) } +func (s *Server) SetProtectedRoute(path string, handler http.Handler) { + if s == nil { + return + } + path = strings.TrimSpace(path) + s.extraRoutesMu.Lock() + defer s.extraRoutesMu.Unlock() + if path == "" || handler == nil { + delete(s.extraRoutes, path) + return + } + s.extraRoutes[path] = handler +} func (s *Server) SetNodeWebRTCTransport(t *nodes.WebRTCTransport) { s.nodeWebRTC = t } @@ -489,6 +505,19 @@ func (s *Server) Start(ctx context.Context) error { mux.HandleFunc("/api/logs/stream", s.handleWebUILogsStream) mux.HandleFunc("/api/logs/live", s.handleWebUILogsLive) mux.HandleFunc("/api/logs/recent", s.handleWebUILogsRecent) + s.extraRoutesMu.RLock() + for path, handler := range s.extraRoutes { + routePath := path + routeHandler := handler + mux.Handle(routePath, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if !s.checkAuth(r) { + http.Error(w, "unauthorized", http.StatusUnauthorized) + return + } + routeHandler.ServeHTTP(w, r) + })) + } + s.extraRoutesMu.RUnlock() base := strings.TrimRight(strings.TrimSpace(s.whatsAppBase), "/") if base == "" { base = "/whatsapp" diff --git a/pkg/providers/aistudio_provider.go b/pkg/providers/aistudio_provider.go new file mode 100644 index 0000000..2ac0889 --- /dev/null +++ b/pkg/providers/aistudio_provider.go @@ -0,0 +1,271 @@ +package providers + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "net/http" + "strings" + "time" + + "github.com/YspCoder/clawgo/pkg/wsrelay" +) + +type AistudioProvider struct { + base *HTTPProvider + relay *wsrelay.Manager +} + +func NewAistudioProvider(providerName, apiKey, apiBase, defaultModel string, supportsResponsesCompact bool, authMode string, timeout time.Duration, oauth *oauthManager) *AistudioProvider { + normalizedBase := normalizeAPIBase(apiBase) + if normalizedBase == "" { + normalizedBase = geminiBaseURL + } + return &AistudioProvider{ + base: NewHTTPProvider(providerName, apiKey, normalizedBase, defaultModel, supportsResponsesCompact, authMode, timeout, oauth), + relay: getAIStudioRelayManager(), + } +} + +func (p *AistudioProvider) GetDefaultModel() string { + if p == nil || p.base == nil { + return "" + } + return p.base.GetDefaultModel() +} + +func (p *AistudioProvider) Chat(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (*LLMResponse, error) { + body, status, ctype, err := p.doRequest(ctx, messages, tools, model, options, false, nil) + if err != nil { + return nil, err + } + if status != http.StatusOK { + return nil, fmt.Errorf("API error (status %d, content-type %q): %s", status, ctype, previewResponseBody(body)) + } + if !json.Valid(body) { + return nil, fmt.Errorf("API error (status %d, content-type %q): non-JSON response: %s", status, ctype, previewResponseBody(body)) + } + return parseGeminiResponse(body) +} + +func (p *AistudioProvider) ChatStream(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}, onDelta func(string)) (*LLMResponse, error) { + body, status, ctype, err := p.doRequest(ctx, messages, tools, model, options, true, onDelta) + if err != nil { + return nil, err + } + if status != http.StatusOK { + return nil, fmt.Errorf("API error (status %d, content-type %q): %s", status, ctype, previewResponseBody(body)) + } + if !json.Valid(body) { + return nil, fmt.Errorf("API error (status %d, content-type %q): non-JSON response: %s", status, ctype, previewResponseBody(body)) + } + return parseGeminiResponse(body) +} + +func (p *AistudioProvider) CountTokens(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (*UsageInfo, error) { + requestBody := p.buildRequestBody(messages, nil, model, options, false) + delete(requestBody, "tools") + delete(requestBody, "toolConfig") + delete(requestBody, "generationConfig") + body, status, ctype, err := p.perform(ctx, p.endpoint(model, "countTokens", false), requestBody, options, false, nil) + if err != nil { + return nil, err + } + if status != http.StatusOK { + return nil, fmt.Errorf("API error (status %d, content-type %q): %s", status, ctype, previewResponseBody(body)) + } + var payload struct { + TotalTokens int `json:"totalTokens"` + } + if err := json.Unmarshal(body, &payload); err != nil { + return nil, fmt.Errorf("invalid countTokens response: %w", err) + } + return &UsageInfo{PromptTokens: payload.TotalTokens, TotalTokens: payload.TotalTokens}, nil +} + +func (p *AistudioProvider) doRequest(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}, stream bool, onDelta func(string)) ([]byte, int, string, error) { + requestBody := p.buildRequestBody(messages, tools, model, options, stream) + return p.perform(ctx, p.endpoint(model, "generateContent", stream), requestBody, options, stream, onDelta) +} + +func (p *AistudioProvider) perform(ctx context.Context, endpoint string, payload map[string]any, options map[string]interface{}, stream bool, onDelta func(string)) ([]byte, int, string, error) { + if p == nil || p.base == nil { + return nil, 0, "", fmt.Errorf("provider not configured") + } + if p.relay == nil { + p.relay = getAIStudioRelayManager() + } + if p.relay == nil { + return nil, 0, "", fmt.Errorf("aistudio relay not configured") + } + jsonData, err := json.Marshal(payload) + if err != nil { + return nil, 0, "", fmt.Errorf("failed to marshal request: %w", err) + } + req := &wsrelay.HTTPRequest{ + Method: http.MethodPost, + URL: endpoint, + Headers: http.Header{ + "Content-Type": []string{"application/json"}, + "Accept": []string{"application/json"}, + }, + Body: jsonData, + } + if stream { + req.Headers.Set("Accept", "text/event-stream") + } + channelIDs := aistudioChannelCandidates(p.base.providerName, options) + if len(channelIDs) == 0 { + return nil, 0, "", fmt.Errorf("aistudio relay channel not specified") + } + if !stream { + var lastErr error + for _, channelID := range channelIDs { + resp, err := p.relay.NonStream(ctx, channelID, req) + if err != nil { + recordAIStudioRelayFailure(channelID, err) + lastErr = err + continue + } + if resp.Status >= 200 && resp.Status < 300 { + recordAIStudioRelaySuccess(channelID) + return resp.Body, resp.Status, strings.TrimSpace(resp.Headers.Get("Content-Type")), nil + } + retryErr := fmt.Errorf("status=%d", resp.Status) + recordAIStudioRelayFailure(channelID, retryErr) + lastErr = retryErr + if resp.Status < 500 && resp.Status != http.StatusTooManyRequests { + return resp.Body, resp.Status, strings.TrimSpace(resp.Headers.Get("Content-Type")), nil + } + } + if lastErr == nil { + lastErr = fmt.Errorf("aistudio relay request failed") + } + return nil, 0, "", lastErr + } + if onDelta == nil { + onDelta = func(string) {} + } + var lastErr error + for _, channelID := range channelIDs { + streamCh, err := p.relay.Stream(ctx, channelID, req) + if err != nil { + recordAIStudioRelayFailure(channelID, err) + lastErr = err + continue + } + state := &antigravityStreamState{} + status := http.StatusOK + ctype := "text/event-stream" + var full bytes.Buffer + started := false + retryable := false + failed := false + for event := range streamCh { + if event.Err != nil { + recordAIStudioRelayFailure(channelID, event.Err) + lastErr = event.Err + retryable = !started + failed = true + break + } + switch event.Type { + case wsrelay.MessageTypeStreamStart: + if event.Status > 0 { + status = event.Status + } + if v := strings.TrimSpace(event.Headers.Get("Content-Type")); v != "" { + ctype = v + } + case wsrelay.MessageTypeStreamChunk: + if len(event.Payload) == 0 { + continue + } + started = true + full.Write(event.Payload) + filtered := filterGeminiSSEUsageMetadata(event.Payload) + if delta := state.consume(filtered); delta != "" { + onDelta(delta) + } + case wsrelay.MessageTypeHTTPResp: + if event.Status > 0 { + status = event.Status + } + if v := strings.TrimSpace(event.Headers.Get("Content-Type")); v != "" { + ctype = v + } + if len(event.Payload) > 0 { + if status >= 200 && status < 300 { + recordAIStudioRelaySuccess(channelID) + } else { + recordAIStudioRelayFailure(channelID, fmt.Errorf("status=%d", status)) + } + if status >= 500 || status == http.StatusTooManyRequests { + lastErr = fmt.Errorf("status=%d", status) + retryable = !started + failed = true + break + } + return event.Payload, status, ctype, nil + } + if status >= 200 && status < 300 { + recordAIStudioRelaySuccess(channelID) + } else { + recordAIStudioRelayFailure(channelID, fmt.Errorf("status=%d", status)) + } + if status >= 500 || status == http.StatusTooManyRequests { + lastErr = fmt.Errorf("status=%d", status) + retryable = !started + failed = true + break + } + return state.finalBody(), status, ctype, nil + case wsrelay.MessageTypeStreamEnd: + if status >= 200 && status < 300 { + recordAIStudioRelaySuccess(channelID) + } else { + recordAIStudioRelayFailure(channelID, fmt.Errorf("status=%d", status)) + } + if status >= 500 || status == http.StatusTooManyRequests { + lastErr = fmt.Errorf("status=%d", status) + retryable = !started + failed = true + break + } + return state.finalBody(), status, ctype, nil + } + } + if failed && started { + break + } + if !failed && full.Len() > 0 { + recordAIStudioRelaySuccess(channelID) + return state.finalBody(), status, ctype, nil + } + if !retryable { + break + } + } + if lastErr == nil { + lastErr = fmt.Errorf("wsrelay: stream closed") + } + return nil, 0, "", lastErr +} + +func (p *AistudioProvider) endpoint(model, action string, stream bool) string { + base := geminiBaseURL + if p != nil && p.base != nil && strings.TrimSpace(p.base.apiBase) != "" && !strings.Contains(strings.ToLower(p.base.apiBase), "api.openai.com") { + base = normalizeGeminiBaseURL(p.base.apiBase) + } + baseModel := strings.TrimSpace(qwenBaseModel(model)) + if stream { + return fmt.Sprintf("%s/%s/models/%s:streamGenerateContent?alt=sse", base, geminiAPIVersion, baseModel) + } + return fmt.Sprintf("%s/%s/models/%s:%s", base, geminiAPIVersion, baseModel, action) +} + +func (p *AistudioProvider) buildRequestBody(messages []Message, tools []ToolDefinition, model string, options map[string]interface{}, stream bool) map[string]any { + gemini := &GeminiProvider{base: p.base} + return gemini.buildRequestBody(messages, tools, model, options, stream) +} diff --git a/pkg/providers/aistudio_provider_test.go b/pkg/providers/aistudio_provider_test.go new file mode 100644 index 0000000..98047f3 --- /dev/null +++ b/pkg/providers/aistudio_provider_test.go @@ -0,0 +1,592 @@ +package providers + +import ( + "context" + "fmt" + "net/http" + "net/http/httptest" + "net/url" + "strings" + "testing" + "time" + + "github.com/YspCoder/clawgo/pkg/config" + "github.com/YspCoder/clawgo/pkg/wsrelay" + "github.com/gorilla/websocket" +) + +func TestCreateProviderByNameRoutesAIStudioProvider(t *testing.T) { + cfg := config.DefaultConfig() + cfg.Models.Providers["aistudio"] = config.ProviderConfig{ + Auth: "none", + TimeoutSec: 90, + Models: []string{"gemini-2.5-pro"}, + } + provider, err := CreateProviderByName(cfg, "aistudio") + if err != nil { + t.Fatalf("CreateProviderByName() error = %v", err) + } + if _, ok := provider.(*AistudioProvider); !ok { + t.Fatalf("expected *AistudioProvider, got %T", provider) + } +} + +func TestAistudioProviderChatUsesRelay(t *testing.T) { + manager, serverURL, cleanup := startRelayTestServer(t) + defer cleanup() + SetAIStudioRelayManager(manager) + + reqCh := make(chan string, 1) + stopClient := connectRelayClient(t, serverURL, "aistudio", func(msg wsrelay.Message) []wsrelay.Message { + reqCh <- fmt.Sprintf("%v", msg.Payload["url"]) + return []wsrelay.Message{{ + ID: msg.ID, + Type: wsrelay.MessageTypeHTTPResp, + Payload: map[string]any{ + "status": 200, + "headers": map[string]any{ + "Content-Type": []string{"application/json"}, + }, + "body": `{"candidates":[{"content":{"parts":[{"text":"ok"}]},"finishReason":"STOP"}],"usageMetadata":{"promptTokenCount":1,"candidatesTokenCount":1,"totalTokenCount":2}}`, + }, + }} + }) + defer stopClient() + + provider := NewAistudioProvider("aistudio", "", "", "gemini-2.5-pro", false, "none", 30*time.Second, nil) + resp, err := provider.Chat(context.Background(), []Message{{Role: "user", Content: "hi"}}, nil, "gemini-2.5-pro", nil) + if err != nil { + t.Fatalf("Chat() error = %v", err) + } + if resp.Content != "ok" { + t.Fatalf("expected content ok, got %q", resp.Content) + } + select { + case raw := <-reqCh: + if !strings.Contains(raw, ":generateContent") { + t.Fatalf("expected generateContent endpoint, got %q", raw) + } + case <-time.After(2 * time.Second): + t.Fatal("timed out waiting for relay request") + } +} + +func TestAistudioProviderCountTokensUsesRelay(t *testing.T) { + manager, serverURL, cleanup := startRelayTestServer(t) + defer cleanup() + SetAIStudioRelayManager(manager) + + stopClient := connectRelayClient(t, serverURL, "aistudio-count", func(msg wsrelay.Message) []wsrelay.Message { + return []wsrelay.Message{{ + ID: msg.ID, + Type: wsrelay.MessageTypeHTTPResp, + Payload: map[string]any{ + "status": 200, + "headers": map[string]any{ + "Content-Type": []string{"application/json"}, + }, + "body": `{"totalTokens":42}`, + }, + }} + }) + defer stopClient() + + provider := NewAistudioProvider("aistudio", "", "", "gemini-2.5-pro", false, "none", 30*time.Second, nil) + usage, err := provider.CountTokens(context.Background(), []Message{{Role: "user", Content: "hi"}}, nil, "gemini-2.5-pro", map[string]interface{}{"aistudio_channel": "aistudio-count"}) + if err != nil { + t.Fatalf("CountTokens() error = %v", err) + } + if usage.TotalTokens != 42 { + t.Fatalf("expected total tokens 42, got %d", usage.TotalTokens) + } +} + +func TestAIStudioRelayRuntimeTracksConnectedChannels(t *testing.T) { + providerRuntimeRegistry.mu.Lock() + delete(providerRuntimeRegistry.api, "aistudio") + providerRuntimeRegistry.mu.Unlock() + + NotifyAIStudioRelayConnected("aistudio") + NotifyAIStudioRelayConnected("aistudio-alt") + + providerRuntimeRegistry.mu.Lock() + state := providerRuntimeRegistry.api["aistudio"] + providerRuntimeRegistry.mu.Unlock() + + if len(state.CandidateOrder) != 2 { + t.Fatalf("expected 2 relay candidates, got %d", len(state.CandidateOrder)) + } + if state.CandidateOrder[0].Kind != "relay" { + t.Fatalf("expected relay candidate kind, got %q", state.CandidateOrder[0].Kind) + } + + NotifyAIStudioRelayDisconnected("aistudio-alt", nil) + + providerRuntimeRegistry.mu.Lock() + state = providerRuntimeRegistry.api["aistudio"] + providerRuntimeRegistry.mu.Unlock() + + if len(state.CandidateOrder) != 1 || state.CandidateOrder[0].Target != "aistudio" { + t.Fatalf("unexpected candidate order after disconnect: %+v", state.CandidateOrder) + } + + NotifyAIStudioRelayDisconnected("aistudio", nil) +} + +func TestGetProviderRuntimeSnapshotIncludesAIStudioRelayAccounts(t *testing.T) { + providerRuntimeRegistry.mu.Lock() + delete(providerRuntimeRegistry.api, "aistudio") + providerRuntimeRegistry.mu.Unlock() + + NotifyAIStudioRelayConnected("aistudio") + defer NotifyAIStudioRelayDisconnected("aistudio", nil) + + cfg := config.DefaultConfig() + cfg.Models.Providers["aistudio"] = config.ProviderConfig{ + TimeoutSec: 30, + Models: []string{"gemini-2.5-pro"}, + } + + snapshot := GetProviderRuntimeSnapshot(cfg) + items, _ := snapshot["items"].([]map[string]interface{}) + if len(items) == 0 { + t.Fatal("expected snapshot items") + } + + var found map[string]interface{} + for _, item := range items { + if strings.TrimSpace(fmt.Sprintf("%v", item["name"])) == "aistudio" { + found = item + break + } + } + if found == nil { + t.Fatal("expected aistudio snapshot item") + } + + accounts, _ := found["oauth_accounts"].([]OAuthAccountInfo) + if len(accounts) != 1 || accounts[0].AccountLabel != "aistudio" { + t.Fatalf("unexpected relay accounts: %+v", accounts) + } +} + +func TestAIStudioRelayAccountsUseLastSuccessAsRefresh(t *testing.T) { + now := time.Now().UTC().Truncate(time.Second) + aistudioRelayRegistry.mu.Lock() + aistudioRelayRegistry.connected = map[string]time.Time{"aistudio": now.Add(-10 * time.Minute)} + aistudioRelayRegistry.succeeded = map[string]time.Time{"aistudio": now} + aistudioRelayRegistry.mu.Unlock() + + accounts := listAIStudioRelayAccounts() + if len(accounts) != 1 { + t.Fatalf("expected one account, got %d", len(accounts)) + } + if accounts[0].LastRefresh != now.Format(time.RFC3339) { + t.Fatalf("expected last refresh %s, got %s", now.Format(time.RFC3339), accounts[0].LastRefresh) + } +} + +func TestAIStudioRelayRuntimeRecordsFailureAndRecovery(t *testing.T) { + providerRuntimeRegistry.mu.Lock() + delete(providerRuntimeRegistry.api, "aistudio") + providerRuntimeRegistry.mu.Unlock() + + NotifyAIStudioRelayConnected("aistudio") + defer NotifyAIStudioRelayDisconnected("aistudio", nil) + + recordAIStudioRelayFailure("aistudio", fmt.Errorf("boom")) + + providerRuntimeRegistry.mu.Lock() + state := providerRuntimeRegistry.api["aistudio"] + providerRuntimeRegistry.mu.Unlock() + + if state.API.FailureCount == 0 { + t.Fatal("expected api failure count to increase") + } + if len(state.RecentErrors) == 0 { + t.Fatal("expected recent errors to be recorded") + } + if len(state.CandidateOrder) == 0 || state.CandidateOrder[0].Status != "cooldown" { + t.Fatalf("expected relay candidate cooldown, got %+v", state.CandidateOrder) + } + + recordAIStudioRelaySuccess("aistudio") + + providerRuntimeRegistry.mu.Lock() + state = providerRuntimeRegistry.api["aistudio"] + providerRuntimeRegistry.mu.Unlock() + + if state.LastSuccess == nil { + t.Fatal("expected last success event") + } + if len(state.CandidateOrder) == 0 || state.CandidateOrder[0].Status != "ready" { + t.Fatalf("expected relay candidate recovery, got %+v", state.CandidateOrder) + } +} + +func TestAIStudioChannelIDPrefersHealthyAvailableRelay(t *testing.T) { + providerRuntimeRegistry.mu.Lock() + providerRuntimeRegistry.api["aistudio"] = providerRuntimeState{ + CandidateOrder: []providerRuntimeCandidate{ + {Kind: "relay", Target: "aistudio-bad", Available: false, Status: "cooldown", CooldownUntil: time.Now().Add(5 * time.Minute).Format(time.RFC3339), HealthScore: 90}, + {Kind: "relay", Target: "aistudio-good", Available: true, Status: "ready", HealthScore: 100}, + }, + } + providerRuntimeRegistry.mu.Unlock() + + got := aistudioChannelID("aistudio", nil) + if got != "aistudio-good" { + t.Fatalf("expected aistudio-good, got %q", got) + } +} + +func TestAIStudioChannelIDExplicitOptionWins(t *testing.T) { + providerRuntimeRegistry.mu.Lock() + providerRuntimeRegistry.api["aistudio"] = providerRuntimeState{ + CandidateOrder: []providerRuntimeCandidate{ + {Kind: "relay", Target: "aistudio-good", Available: true, Status: "ready", HealthScore: 100}, + }, + } + providerRuntimeRegistry.mu.Unlock() + + got := aistudioChannelID("aistudio", map[string]interface{}{"aistudio_channel": "manual"}) + if got != "manual" { + t.Fatalf("expected explicit channel manual, got %q", got) + } +} + +func TestAIStudioChannelIDPrefersMostRecentSuccessfulRelay(t *testing.T) { + providerRuntimeRegistry.mu.Lock() + providerRuntimeRegistry.api["aistudio"] = providerRuntimeState{ + CandidateOrder: []providerRuntimeCandidate{ + {Kind: "relay", Target: "aistudio-a", Available: true, Status: "ready", HealthScore: 100}, + {Kind: "relay", Target: "aistudio-b", Available: true, Status: "ready", HealthScore: 100}, + }, + } + providerRuntimeRegistry.mu.Unlock() + + aistudioRelayRegistry.mu.Lock() + if aistudioRelayRegistry.succeeded == nil { + aistudioRelayRegistry.succeeded = map[string]time.Time{} + } + aistudioRelayRegistry.succeeded["aistudio-a"] = time.Now().Add(-1 * time.Minute) + aistudioRelayRegistry.succeeded["aistudio-b"] = time.Now() + aistudioRelayRegistry.mu.Unlock() + + got := aistudioChannelID("aistudio", nil) + if got != "aistudio-b" { + t.Fatalf("expected most recent successful relay aistudio-b, got %q", got) + } +} + +func TestAistudioProviderChatFailsOverToNextRelay(t *testing.T) { + manager, serverURL, cleanup := startRelayTestServer(t) + defer cleanup() + SetAIStudioRelayManager(manager) + aistudioRelayRegistry.mu.Lock() + aistudioRelayRegistry.succeeded = map[string]time.Time{} + aistudioRelayRegistry.mu.Unlock() + + providerRuntimeRegistry.mu.Lock() + providerRuntimeRegistry.api["aistudio"] = providerRuntimeState{ + CandidateOrder: []providerRuntimeCandidate{ + {Kind: "relay", Target: "aistudio-first", Available: true, Status: "ready", HealthScore: 100}, + {Kind: "relay", Target: "aistudio-second", Available: true, Status: "ready", HealthScore: 90}, + }, + } + providerRuntimeRegistry.mu.Unlock() + + connectRelayClient(t, serverURL, "aistudio-first", func(msg wsrelay.Message) []wsrelay.Message { + return []wsrelay.Message{{ + ID: msg.ID, + Type: wsrelay.MessageTypeHTTPResp, + Payload: map[string]any{ + "status": 503, + "headers": map[string]any{ + "Content-Type": []string{"application/json"}, + }, + "body": `{"error":"no capacity"}`, + }, + }} + }) + stopSecond := connectRelayClient(t, serverURL, "aistudio-second", func(msg wsrelay.Message) []wsrelay.Message { + return []wsrelay.Message{{ + ID: msg.ID, + Type: wsrelay.MessageTypeHTTPResp, + Payload: map[string]any{ + "status": 200, + "headers": map[string]any{ + "Content-Type": []string{"application/json"}, + }, + "body": `{"candidates":[{"content":{"parts":[{"text":"ok-2"}]},"finishReason":"STOP"}]}`, + }, + }} + }) + defer stopSecond() + + provider := NewAistudioProvider("aistudio", "", "", "gemini-2.5-pro", false, "none", 30*time.Second, nil) + resp, err := provider.Chat(context.Background(), []Message{{Role: "user", Content: "hi"}}, nil, "gemini-2.5-pro", nil) + if err != nil { + t.Fatalf("Chat() error = %v", err) + } + if resp.Content != "ok-2" { + t.Fatalf("expected failover response ok-2, got %q", resp.Content) + } +} + +func TestAistudioProviderChatExplicitChannelDoesNotFailOver(t *testing.T) { + manager, serverURL, cleanup := startRelayTestServer(t) + defer cleanup() + SetAIStudioRelayManager(manager) + aistudioRelayRegistry.mu.Lock() + aistudioRelayRegistry.succeeded = map[string]time.Time{} + aistudioRelayRegistry.mu.Unlock() + + stopFirst := connectRelayClient(t, serverURL, "manual", func(msg wsrelay.Message) []wsrelay.Message { + return []wsrelay.Message{{ + ID: msg.ID, + Type: wsrelay.MessageTypeHTTPResp, + Payload: map[string]any{ + "status": 503, + "headers": map[string]any{ + "Content-Type": []string{"application/json"}, + }, + "body": `{"error":"no capacity"}`, + }, + }} + }) + defer stopFirst() + + stopSecond := connectRelayClient(t, serverURL, "aistudio-second", func(msg wsrelay.Message) []wsrelay.Message { + return []wsrelay.Message{{ + ID: msg.ID, + Type: wsrelay.MessageTypeHTTPResp, + Payload: map[string]any{ + "status": 200, + "headers": map[string]any{ + "Content-Type": []string{"application/json"}, + }, + "body": `{"candidates":[{"content":{"parts":[{"text":"ok-2"}]},"finishReason":"STOP"}]}`, + }, + }} + }) + defer stopSecond() + + provider := NewAistudioProvider("aistudio", "", "", "gemini-2.5-pro", false, "none", 30*time.Second, nil) + _, err := provider.Chat(context.Background(), []Message{{Role: "user", Content: "hi"}}, nil, "gemini-2.5-pro", map[string]interface{}{"aistudio_channel": "manual"}) + if err == nil { + t.Fatal("expected explicit relay failure without failover") + } +} + +func TestAistudioProviderStreamFailsOverBeforeFirstChunk(t *testing.T) { + manager, serverURL, cleanup := startRelayTestServer(t) + defer cleanup() + SetAIStudioRelayManager(manager) + aistudioRelayRegistry.mu.Lock() + aistudioRelayRegistry.succeeded = map[string]time.Time{} + aistudioRelayRegistry.mu.Unlock() + + providerRuntimeRegistry.mu.Lock() + providerRuntimeRegistry.api["aistudio"] = providerRuntimeState{ + CandidateOrder: []providerRuntimeCandidate{ + {Kind: "relay", Target: "aistudio-first", Available: true, Status: "ready", HealthScore: 100}, + {Kind: "relay", Target: "aistudio-second", Available: true, Status: "ready", HealthScore: 90}, + }, + } + providerRuntimeRegistry.mu.Unlock() + + stopFirst := connectRelayClient(t, serverURL, "aistudio-first", func(msg wsrelay.Message) []wsrelay.Message { + return []wsrelay.Message{ + { + ID: msg.ID, + Type: wsrelay.MessageTypeStreamStart, + Payload: map[string]any{ + "status": 503, + "headers": map[string]any{ + "Content-Type": []string{"text/event-stream"}, + }, + }, + }, + { + ID: msg.ID, + Type: wsrelay.MessageTypeStreamEnd, + }, + } + }) + defer stopFirst() + stopSecond := connectRelayClient(t, serverURL, "aistudio-second", func(msg wsrelay.Message) []wsrelay.Message { + return []wsrelay.Message{ + { + ID: msg.ID, + Type: wsrelay.MessageTypeStreamStart, + Payload: map[string]any{ + "status": 200, + "headers": map[string]any{ + "Content-Type": []string{"text/event-stream"}, + }, + }, + }, + { + ID: msg.ID, + Type: wsrelay.MessageTypeStreamChunk, + Payload: map[string]any{ + "data": `{"candidates":[{"content":{"parts":[{"text":"ok-stream"}]}}]}`, + }, + }, + { + ID: msg.ID, + Type: wsrelay.MessageTypeStreamEnd, + }, + } + }) + defer stopSecond() + + provider := NewAistudioProvider("aistudio", "", "", "gemini-2.5-pro", false, "none", 30*time.Second, nil) + var deltas []string + resp, err := provider.ChatStream(context.Background(), []Message{{Role: "user", Content: "hi"}}, nil, "gemini-2.5-pro", nil, func(delta string) { + deltas = append(deltas, delta) + }) + if err != nil { + t.Fatalf("ChatStream() error = %v", err) + } + if resp.Content != "ok-stream" { + t.Fatalf("expected failover stream content ok-stream, got %q", resp.Content) + } + if len(deltas) == 0 { + t.Fatal("expected stream deltas after failover") + } +} + +func TestAistudioProviderStreamDoesNotFailOverAfterChunk(t *testing.T) { + manager, serverURL, cleanup := startRelayTestServer(t) + defer cleanup() + SetAIStudioRelayManager(manager) + aistudioRelayRegistry.mu.Lock() + aistudioRelayRegistry.succeeded = map[string]time.Time{} + aistudioRelayRegistry.mu.Unlock() + + providerRuntimeRegistry.mu.Lock() + providerRuntimeRegistry.api["aistudio"] = providerRuntimeState{ + CandidateOrder: []providerRuntimeCandidate{ + {Kind: "relay", Target: "aistudio-first", Available: true, Status: "ready", HealthScore: 100}, + {Kind: "relay", Target: "aistudio-second", Available: true, Status: "ready", HealthScore: 90}, + }, + } + providerRuntimeRegistry.mu.Unlock() + + stopFirst := connectRelayClient(t, serverURL, "aistudio-first", func(msg wsrelay.Message) []wsrelay.Message { + return []wsrelay.Message{ + { + ID: msg.ID, + Type: wsrelay.MessageTypeStreamStart, + Payload: map[string]any{ + "status": 200, + "headers": map[string]any{ + "Content-Type": []string{"text/event-stream"}, + }, + }, + }, + { + ID: msg.ID, + Type: wsrelay.MessageTypeStreamChunk, + Payload: map[string]any{ + "data": `{"candidates":[{"content":{"parts":[{"text":"partial"}]}}]}`, + }, + }, + { + ID: msg.ID, + Type: wsrelay.MessageTypeError, + Payload: map[string]any{"error": "stream broke", "status": 502.0}, + }, + } + }) + defer stopFirst() + stopSecond := connectRelayClient(t, serverURL, "aistudio-second", func(msg wsrelay.Message) []wsrelay.Message { + return []wsrelay.Message{ + { + ID: msg.ID, + Type: wsrelay.MessageTypeStreamStart, + Payload: map[string]any{ + "status": 200, + "headers": map[string]any{ + "Content-Type": []string{"text/event-stream"}, + }, + }, + }, + { + ID: msg.ID, + Type: wsrelay.MessageTypeStreamChunk, + Payload: map[string]any{ + "data": `{"candidates":[{"content":{"parts":[{"text":"should-not-use"}]}}]}`, + }, + }, + { + ID: msg.ID, + Type: wsrelay.MessageTypeStreamEnd, + }, + } + }) + defer stopSecond() + + provider := NewAistudioProvider("aistudio", "", "", "gemini-2.5-pro", false, "none", 30*time.Second, nil) + _, err := provider.ChatStream(context.Background(), []Message{{Role: "user", Content: "hi"}}, nil, "gemini-2.5-pro", nil, nil) + if err == nil { + t.Fatal("expected stream error without mid-stream failover") + } + if strings.Contains(err.Error(), "should-not-use") { + t.Fatalf("unexpected failover after stream started: %v", err) + } +} + +func startRelayTestServer(t *testing.T) (*wsrelay.Manager, string, func()) { + t.Helper() + manager := wsrelay.NewManager(wsrelay.Options{ + Path: "/v1/ws", + ProviderFactory: func(r *http.Request) (string, error) { + return strings.ToLower(strings.TrimSpace(r.URL.Query().Get("provider"))), nil + }, + }) + srv := httptest.NewServer(manager.Handler()) + wsURL := "ws" + strings.TrimPrefix(srv.URL, "http") + manager.Path() + cleanup := func() { + _ = manager.Stop(context.Background()) + srv.Close() + SetAIStudioRelayManager(nil) + } + return manager, wsURL, cleanup +} + +func connectRelayClient(t *testing.T, wsURL, provider string, handle func(wsrelay.Message) []wsrelay.Message) func() { + t.Helper() + u, err := url.Parse(wsURL) + if err != nil { + t.Fatalf("url.Parse() error = %v", err) + } + q := u.Query() + q.Set("provider", provider) + u.RawQuery = q.Encode() + conn, _, err := websocket.DefaultDialer.Dial(u.String(), nil) + if err != nil { + t.Fatalf("Dial() error = %v", err) + } + done := make(chan struct{}) + go func() { + defer close(done) + for { + var msg wsrelay.Message + if err := conn.ReadJSON(&msg); err != nil { + return + } + for _, out := range handle(msg) { + if err := conn.WriteJSON(out); err != nil { + return + } + } + } + }() + return func() { + _ = conn.Close() + <-done + } +} diff --git a/pkg/providers/aistudio_relay.go b/pkg/providers/aistudio_relay.go new file mode 100644 index 0000000..ee9947e --- /dev/null +++ b/pkg/providers/aistudio_relay.go @@ -0,0 +1,367 @@ +package providers + +import ( + "fmt" + "sort" + "strings" + "sync" + "time" + + "github.com/YspCoder/clawgo/pkg/wsrelay" +) + +var aistudioRelayRegistry struct { + mu sync.RWMutex + manager *wsrelay.Manager + connected map[string]time.Time + succeeded map[string]time.Time +} + +func SetAIStudioRelayManager(manager *wsrelay.Manager) { + aistudioRelayRegistry.mu.Lock() + aistudioRelayRegistry.manager = manager + if aistudioRelayRegistry.connected == nil { + aistudioRelayRegistry.connected = map[string]time.Time{} + } + if aistudioRelayRegistry.succeeded == nil { + aistudioRelayRegistry.succeeded = map[string]time.Time{} + } + aistudioRelayRegistry.mu.Unlock() +} + +func getAIStudioRelayManager() *wsrelay.Manager { + aistudioRelayRegistry.mu.RLock() + defer aistudioRelayRegistry.mu.RUnlock() + return aistudioRelayRegistry.manager +} + +func aistudioChannelID(providerName string, options map[string]interface{}) string { + channels := aistudioChannelCandidates(providerName, options) + if len(channels) > 0 { + return channels[0] + } + return "" +} + +func aistudioChannelCandidates(providerName string, options map[string]interface{}) []string { + for _, key := range []string{"aistudio_channel", "aistudio_provider", "relay_provider"} { + if value, ok := stringOption(options, key); ok && strings.TrimSpace(value) != "" { + return []string{strings.ToLower(strings.TrimSpace(value))} + } + } + if runtimeSelected := preferredAIStudioRelayChannels(); len(runtimeSelected) > 0 { + return runtimeSelected + } + if fallback := strings.ToLower(strings.TrimSpace(providerName)); fallback != "" { + return []string{fallback} + } + return nil +} + +func preferredAIStudioRelayChannels() []string { + providerRuntimeRegistry.mu.Lock() + defer providerRuntimeRegistry.mu.Unlock() + state := providerRuntimeRegistry.api["aistudio"] + candidates := append([]providerRuntimeCandidate(nil), state.CandidateOrder...) + sortAIStudioRelayCandidates(candidates) + out := make([]string, 0, len(candidates)) + fallbacks := make([]string, 0, len(candidates)) + for _, candidate := range candidates { + if candidate.Kind != "relay" { + continue + } + target := strings.TrimSpace(candidate.Target) + if target == "" { + continue + } + fallbacks = append(fallbacks, target) + if !candidate.Available { + continue + } + if cooldownActive(candidate.CooldownUntil) { + continue + } + out = append(out, target) + } + if len(out) > 0 { + return out + } + return fallbacks +} + +func NotifyAIStudioRelayConnected(channelID string) { + channelID = strings.ToLower(strings.TrimSpace(channelID)) + if channelID == "" { + return + } + aistudioRelayRegistry.mu.Lock() + if aistudioRelayRegistry.connected == nil { + aistudioRelayRegistry.connected = map[string]time.Time{} + } + if aistudioRelayRegistry.succeeded == nil { + aistudioRelayRegistry.succeeded = map[string]time.Time{} + } + aistudioRelayRegistry.connected[channelID] = time.Now().UTC() + channels := aistudioRelayChannelsLocked() + aistudioRelayRegistry.mu.Unlock() + updateAIStudioRelayRuntime(channels) + recordProviderRuntimeChange("aistudio", "relay", channelID, "relay_connected", "aistudio websocket relay connected") +} + +func NotifyAIStudioRelayDisconnected(channelID string, cause error) { + channelID = strings.ToLower(strings.TrimSpace(channelID)) + if channelID == "" { + return + } + aistudioRelayRegistry.mu.Lock() + if aistudioRelayRegistry.connected != nil { + delete(aistudioRelayRegistry.connected, channelID) + } + if aistudioRelayRegistry.succeeded != nil { + delete(aistudioRelayRegistry.succeeded, channelID) + } + channels := aistudioRelayChannelsLocked() + aistudioRelayRegistry.mu.Unlock() + updateAIStudioRelayRuntime(channels) + detail := "aistudio websocket relay disconnected" + if cause != nil { + detail = fmt.Sprintf("%s: %v", detail, cause) + } + recordProviderRuntimeChange("aistudio", "relay", channelID, "relay_disconnected", detail) +} + +func updateAIStudioRelayRuntime(channels []string) { + candidates := make([]providerRuntimeCandidate, 0, len(channels)) + for _, channelID := range channels { + health, failures, cooldown, _ := aistudioRelayHealth(channelID) + status := "ready" + available := true + if cooldown != "" { + status = "cooldown" + available = false + } + candidates = append(candidates, providerRuntimeCandidate{ + Kind: "relay", + Target: channelID, + Available: available, + Status: status, + CooldownUntil: cooldown, + HealthScore: health, + FailureCount: failures, + }) + } + sortAIStudioRelayCandidates(candidates) + providerRuntimeRegistry.mu.Lock() + defer providerRuntimeRegistry.mu.Unlock() + state := providerRuntimeRegistry.api["aistudio"] + if !providerCandidatesEqual(state.CandidateOrder, candidates) { + state.RecentChanges = appendRuntimeEvent(state.RecentChanges, providerRuntimeEvent{ + When: time.Now().Format(time.RFC3339), + Kind: "relay", + Target: "aistudio", + Reason: "candidate_order_changed", + Detail: candidateOrderChangeDetail(state.CandidateOrder, candidates), + }, runtimeEventLimit(state)) + } + state.CandidateOrder = candidates + persistProviderRuntimeLocked("aistudio", state) + providerRuntimeRegistry.api["aistudio"] = state +} + +func aistudioRelayChannelsLocked() []string { + out := make([]string, 0, len(aistudioRelayRegistry.connected)) + for channelID := range aistudioRelayRegistry.connected { + out = append(out, channelID) + } + sort.Strings(out) + return out +} + +func listAIStudioRelayAccounts() []OAuthAccountInfo { + aistudioRelayRegistry.mu.RLock() + defer aistudioRelayRegistry.mu.RUnlock() + if len(aistudioRelayRegistry.connected) == 0 { + return nil + } + channels := aistudioRelayChannelsLocked() + out := make([]OAuthAccountInfo, 0, len(channels)) + for _, channelID := range channels { + connectedAt := aistudioRelayRegistry.connected[channelID] + health, failures, cooldown, lastSuccess := aistudioRelayHealth(channelID) + lastRefresh := connectedAt.Format(time.RFC3339) + if !lastSuccess.IsZero() { + lastRefresh = lastSuccess.Format(time.RFC3339) + } + out = append(out, OAuthAccountInfo{ + Email: channelID, + AccountID: channelID, + AccountLabel: channelID, + LastRefresh: lastRefresh, + HealthScore: health, + FailureCount: failures, + CooldownUntil: cooldown, + PlanType: "relay", + QuotaSource: "relay", + BalanceLabel: "connected", + }) + } + return out +} + +func aistudioRelayHealth(channelID string) (health int, failures int, cooldown string, lastSuccess time.Time) { + providerRuntimeRegistry.mu.Lock() + defer providerRuntimeRegistry.mu.Unlock() + state := providerRuntimeRegistry.api["aistudio"] + if state.API.HealthScore <= 0 { + health = 100 + } else { + health = state.API.HealthScore + } + for _, candidate := range state.CandidateOrder { + if candidate.Kind == "relay" && candidate.Target == channelID { + if candidate.HealthScore > 0 { + health = candidate.HealthScore + } + failures = candidate.FailureCount + cooldown = strings.TrimSpace(candidate.CooldownUntil) + break + } + } + aistudioRelayRegistry.mu.RLock() + lastSuccess = aistudioRelayRegistry.succeeded[channelID] + aistudioRelayRegistry.mu.RUnlock() + return health, failures, cooldown, lastSuccess +} + +func recordAIStudioRelaySuccess(channelID string) { + aistudioRelayRegistry.mu.Lock() + if aistudioRelayRegistry.succeeded == nil { + aistudioRelayRegistry.succeeded = map[string]time.Time{} + } + aistudioRelayRegistry.succeeded[channelID] = time.Now().UTC() + aistudioRelayRegistry.mu.Unlock() + updateAIStudioRelayAttempt(channelID, "", true) +} + +func recordAIStudioRelayFailure(channelID string, err error) { + reason := "relay_error" + if err != nil && strings.TrimSpace(err.Error()) != "" { + reason = strings.TrimSpace(err.Error()) + } + updateAIStudioRelayAttempt(channelID, reason, false) +} + +func updateAIStudioRelayAttempt(channelID, reason string, success bool) { + channelID = strings.ToLower(strings.TrimSpace(channelID)) + if channelID == "" { + return + } + now := time.Now() + providerRuntimeRegistry.mu.Lock() + defer providerRuntimeRegistry.mu.Unlock() + state := providerRuntimeRegistry.api["aistudio"] + if state.API.HealthScore <= 0 { + state.API.HealthScore = 100 + } + found := false + for i := range state.CandidateOrder { + candidate := &state.CandidateOrder[i] + if candidate.Kind != "relay" || candidate.Target != channelID { + continue + } + found = true + if success { + candidate.HealthScore = minInt(100, maxInt(candidate.HealthScore, 100)+3) + candidate.FailureCount = 0 + candidate.CooldownUntil = "" + candidate.Available = true + candidate.Status = "ready" + } else { + if candidate.HealthScore <= 0 { + candidate.HealthScore = 100 + } + candidate.HealthScore = maxInt(1, candidate.HealthScore-20) + candidate.FailureCount++ + candidate.CooldownUntil = now.Add(5 * time.Minute).Format(time.RFC3339) + candidate.Available = false + candidate.Status = "cooldown" + state.RecentErrors = appendRuntimeEvent(state.RecentErrors, providerRuntimeEvent{ + When: now.Format(time.RFC3339), + Kind: "relay", + Target: channelID, + Reason: reason, + }, runtimeEventLimit(state)) + } + break + } + if !found { + candidate := providerRuntimeCandidate{ + Kind: "relay", + Target: channelID, + Available: success, + Status: "ready", + HealthScore: 100, + } + if !success { + candidate.Status = "cooldown" + candidate.Available = false + candidate.FailureCount = 1 + candidate.HealthScore = 80 + candidate.CooldownUntil = now.Add(5 * time.Minute).Format(time.RFC3339) + state.RecentErrors = appendRuntimeEvent(state.RecentErrors, providerRuntimeEvent{ + When: now.Format(time.RFC3339), + Kind: "relay", + Target: channelID, + Reason: reason, + }, runtimeEventLimit(state)) + } + state.CandidateOrder = append(state.CandidateOrder, candidate) + sortRuntimeCandidates(state.CandidateOrder) + } + if success { + state.API.HealthScore = minInt(100, state.API.HealthScore+2) + state.API.CooldownUntil = "" + state.LastSuccess = &providerRuntimeEvent{ + When: now.Format(time.RFC3339), + Kind: "relay", + Target: channelID, + Reason: "success", + } + state.RecentHits = appendRuntimeEvent(state.RecentHits, *state.LastSuccess, runtimeEventLimit(state)) + } else { + state.API.HealthScore = maxInt(1, state.API.HealthScore-10) + state.API.FailureCount++ + state.API.LastFailure = reason + state.API.CooldownUntil = now.Add(5 * time.Minute).Format(time.RFC3339) + } + persistProviderRuntimeLocked("aistudio", state) + providerRuntimeRegistry.api["aistudio"] = state +} + +func sortAIStudioRelayCandidates(items []providerRuntimeCandidate) { + sort.SliceStable(items, func(i, j int) bool { + left := items[i] + right := items[j] + if left.Available != right.Available { + return left.Available + } + leftSuccess := aistudioRelayLastSuccess(left.Target) + rightSuccess := aistudioRelayLastSuccess(right.Target) + if !leftSuccess.Equal(rightSuccess) { + return leftSuccess.After(rightSuccess) + } + if left.HealthScore != right.HealthScore { + return left.HealthScore > right.HealthScore + } + return left.Target < right.Target + }) +} + +func aistudioRelayLastSuccess(channelID string) time.Time { + aistudioRelayRegistry.mu.RLock() + defer aistudioRelayRegistry.mu.RUnlock() + if aistudioRelayRegistry.succeeded == nil { + return time.Time{} + } + return aistudioRelayRegistry.succeeded[channelID] +} diff --git a/pkg/providers/antigravity_provider.go b/pkg/providers/antigravity_provider.go index c2acd41..22fb52c 100644 --- a/pkg/providers/antigravity_provider.go +++ b/pkg/providers/antigravity_provider.go @@ -15,6 +15,7 @@ import ( const ( antigravityDailyBaseURL = "https://daily-cloudcode-pa.googleapis.com" antigravitySandboxBaseURL = "https://daily-cloudcode-pa.sandbox.googleapis.com" + antigravityProdBaseURL = "https://cloudcode-pa.googleapis.com" ) type AntigravityProvider struct { @@ -38,6 +39,64 @@ func (p *AntigravityProvider) GetDefaultModel() string { return p.base.GetDefaultModel() } +func (p *AntigravityProvider) CountTokens(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (*UsageInfo, error) { + if p == nil || p.base == nil { + return nil, fmt.Errorf("provider not configured") + } + attempts, err := p.base.authAttempts(ctx) + if err != nil { + return nil, err + } + var lastBody []byte + var lastStatus int + var lastType string + for _, attempt := range attempts { + for _, baseURL := range p.baseURLs() { + requestBody := p.buildRequestBody(messages, tools, model, options, attempt.session, false) + delete(requestBody, "project") + delete(requestBody, "model") + request := mapFromAny(requestBody["request"]) + delete(request, "safetySettings") + requestBody["request"] = request + body, status, ctype, reqErr := p.performCountTokensAttempt(ctx, p.countTokensEndpoint(baseURL), requestBody, attempt) + if reqErr != nil { + if strings.Contains(strings.ToLower(reqErr.Error()), "context canceled") || strings.Contains(strings.ToLower(reqErr.Error()), "deadline exceeded") { + return nil, reqErr + } + lastBody, lastStatus, lastType = nil, 0, "" + continue + } + lastBody, lastStatus, lastType = body, status, ctype + if status == http.StatusTooManyRequests { + continue + } + reason, retry := classifyOAuthFailure(status, body) + if retry { + if attempt.kind == "oauth" && attempt.session != nil && p.base.oauth != nil { + p.base.oauth.markExhausted(attempt.session, reason) + recordProviderOAuthError(p.base.providerName, attempt.session, reason) + } + if attempt.kind == "api_key" { + p.base.markAPIKeyFailure(reason) + } + break + } + if status != http.StatusOK { + return nil, fmt.Errorf("API error (status %d, content-type %q): %s", status, ctype, previewResponseBody(body)) + } + var payload struct { + TotalTokens int `json:"totalTokens"` + } + if err := json.Unmarshal(body, &payload); err != nil { + return nil, fmt.Errorf("invalid countTokens response: %w", err) + } + p.base.markAttemptSuccess(attempt) + return &UsageInfo{PromptTokens: payload.TotalTokens, TotalTokens: payload.TotalTokens}, nil + } + } + return nil, fmt.Errorf("API error (status %d, content-type %q): %s", lastStatus, lastType, previewResponseBody(lastBody)) +} + func (p *AntigravityProvider) Chat(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (*LLMResponse, error) { body, status, ctype, err := p.doRequest(ctx, messages, tools, model, options, false, nil) if err != nil { @@ -81,31 +140,39 @@ func (p *AntigravityProvider) doRequest(ctx context.Context, messages []Message, for _, baseURL := range p.baseURLs() { requestBody := p.buildRequestBody(messages, tools, model, options, attempt.session, stream) endpoint := p.endpoint(baseURL, stream) - body, status, ctype, reqErr := p.performAttempt(ctx, endpoint, requestBody, attempt, stream, onDelta) - if reqErr != nil { - if strings.Contains(strings.ToLower(reqErr.Error()), "context canceled") || strings.Contains(strings.ToLower(reqErr.Error()), "deadline exceeded") { - return nil, 0, "", reqErr + for retryAttempt := 0; retryAttempt < 3; retryAttempt++ { + body, status, ctype, reqErr := p.performAttempt(ctx, endpoint, requestBody, attempt, stream, onDelta) + if reqErr != nil { + if strings.Contains(strings.ToLower(reqErr.Error()), "context canceled") || strings.Contains(strings.ToLower(reqErr.Error()), "deadline exceeded") { + return nil, 0, "", reqErr + } + lastBody, lastStatus, lastType = nil, 0, "" + break } - lastBody, lastStatus, lastType = nil, 0, "" - continue - } - lastBody, lastStatus, lastType = body, status, ctype - if status == http.StatusTooManyRequests || status == http.StatusServiceUnavailable || status == http.StatusBadGateway { - continue - } - reason, retry := classifyOAuthFailure(status, body) - if retry { - if attempt.kind == "oauth" && attempt.session != nil && p.base.oauth != nil { - p.base.oauth.markExhausted(attempt.session, reason) - recordProviderOAuthError(p.base.providerName, attempt.session, reason) + lastBody, lastStatus, lastType = body, status, ctype + if antigravityShouldRetryNoCapacity(status, body) && retryAttempt < 2 { + if err := antigravityWait(ctx, antigravityNoCapacityRetryDelay(retryAttempt)); err != nil { + return nil, 0, "", err + } + continue } - if attempt.kind == "api_key" { - p.base.markAPIKeyFailure(reason) + if status == http.StatusTooManyRequests || status == http.StatusServiceUnavailable || status == http.StatusBadGateway { + break } - break + reason, retry := classifyOAuthFailure(status, body) + if retry { + if attempt.kind == "oauth" && attempt.session != nil && p.base.oauth != nil { + p.base.oauth.markExhausted(attempt.session, reason) + recordProviderOAuthError(p.base.providerName, attempt.session, reason) + } + if attempt.kind == "api_key" { + p.base.markAPIKeyFailure(reason) + } + break + } + p.base.markAttemptSuccess(attempt) + return body, status, ctype, nil } - p.base.markAttemptSuccess(attempt) - return body, status, ctype, nil } } return lastBody, lastStatus, lastType, nil @@ -163,14 +230,24 @@ func (p *AntigravityProvider) endpoint(baseURL string, stream bool) string { return base + path } +func (p *AntigravityProvider) countTokensEndpoint(baseURL string) string { + base := normalizeAPIBase(baseURL) + if base == "" { + base = antigravityDailyBaseURL + } + return base + "/" + defaultAntigravityAPIVersion + ":countTokens" +} + func (p *AntigravityProvider) baseURLs() []string { if p == nil || p.base == nil { return []string{antigravityDailyBaseURL} } - if custom := normalizeAPIBase(p.base.apiBase); custom != "" && !strings.Contains(strings.ToLower(custom), "api.openai.com") { + if custom := normalizeAPIBase(p.base.apiBase); custom != "" && + !strings.Contains(strings.ToLower(custom), "api.openai.com") && + custom != antigravityDailyBaseURL { return []string{custom} } - return []string{antigravityDailyBaseURL, antigravitySandboxBaseURL, defaultAntigravityAPIEndpoint} + return []string{antigravityDailyBaseURL, antigravitySandboxBaseURL, antigravityProdBaseURL, defaultAntigravityAPIEndpoint} } func (p *AntigravityProvider) buildRequestBody(messages []Message, tools []ToolDefinition, model string, options map[string]interface{}, session *oauthSession, stream bool) map[string]any { @@ -409,6 +486,70 @@ func consumeAntigravityStream(resp *http.Response, onDelta func(string)) ([]byte return state.finalBody(), resp.StatusCode, strings.TrimSpace(resp.Header.Get("Content-Type")), nil } +func (p *AntigravityProvider) performCountTokensAttempt(ctx context.Context, endpoint string, payload map[string]any, attempt authAttempt) ([]byte, int, string, error) { + jsonData, err := json.Marshal(payload) + if err != nil { + return nil, 0, "", fmt.Errorf("failed to marshal request: %w", err) + } + req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(jsonData)) + if err != nil { + return nil, 0, "", fmt.Errorf("failed to create request: %w", err) + } + req.Close = true + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept", "application/json") + req.Header.Set("User-Agent", defaultAntigravityAPIUserAgent) + req.Header.Set("X-Goog-Api-Client", defaultAntigravityAPIClient) + req.Header.Set("Client-Metadata", defaultAntigravityClientMeta) + applyAttemptAuth(req, attempt) + client, err := p.base.httpClientForAttempt(attempt) + if err != nil { + return nil, 0, "", err + } + resp, err := client.Do(req) + if err != nil { + return nil, 0, "", fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + body, readErr := io.ReadAll(resp.Body) + if readErr != nil { + return nil, resp.StatusCode, strings.TrimSpace(resp.Header.Get("Content-Type")), fmt.Errorf("failed to read response: %w", readErr) + } + return body, resp.StatusCode, strings.TrimSpace(resp.Header.Get("Content-Type")), nil +} + +func antigravityShouldRetryNoCapacity(statusCode int, body []byte) bool { + if statusCode != http.StatusServiceUnavailable { + return false + } + return strings.Contains(strings.ToLower(string(body)), "no capacity available") +} + +func antigravityNoCapacityRetryDelay(attempt int) time.Duration { + if attempt < 0 { + attempt = 0 + } + delay := time.Duration(attempt+1) * 250 * time.Millisecond + if delay > 2*time.Second { + delay = 2 * time.Second + } + return delay +} + +func antigravityWait(ctx context.Context, wait time.Duration) error { + if wait <= 0 { + return nil + } + timer := time.NewTimer(wait) + defer timer.Stop() + select { + case <-ctx.Done(): + return ctx.Err() + case <-timer.C: + return nil + } +} + type antigravityStreamState struct { Text string ToolCalls []ToolCall diff --git a/pkg/providers/antigravity_provider_test.go b/pkg/providers/antigravity_provider_test.go index 8093ccd..8d91fe4 100644 --- a/pkg/providers/antigravity_provider_test.go +++ b/pkg/providers/antigravity_provider_test.go @@ -2,7 +2,11 @@ package providers import ( "encoding/json" + "net/http" + "net/http/httptest" + "sync/atomic" "testing" + "time" ) func TestAntigravityBuildRequestBody(t *testing.T) { @@ -99,3 +103,66 @@ func TestParseAntigravityResponse(t *testing.T) { t.Fatalf("expected tool args, got %#v", args) } } + +func TestAntigravityProviderCountTokens(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/v1internal:countTokens" { + http.NotFound(w, r) + return + } + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"totalTokens":42}`)) + })) + defer server.Close() + + p := NewAntigravityProvider("antigravity", "token", server.URL, "gemini-2.5-pro", false, "api_key", 5*time.Second, nil) + usage, err := p.CountTokens(t.Context(), []Message{{Role: "user", Content: "hello"}}, nil, "gemini-2.5-pro", nil) + if err != nil { + t.Fatalf("CountTokens error: %v", err) + } + if usage == nil || usage.PromptTokens != 42 || usage.TotalTokens != 42 { + t.Fatalf("usage = %#v, want 42", usage) + } +} + +func TestAntigravityProviderRetriesNoCapacity(t *testing.T) { + var hits int32 + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/v1internal:generateContent" { + http.NotFound(w, r) + return + } + if atomic.AddInt32(&hits, 1) == 1 { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusServiceUnavailable) + _, _ = w.Write([]byte(`{"error":{"message":"no capacity available"}}`)) + return + } + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"response":{"candidates":[{"content":{"parts":[{"text":"ok"}]}}]}}`)) + })) + defer server.Close() + + p := NewAntigravityProvider("antigravity", "token", server.URL, "gemini-2.5-pro", false, "api_key", 5*time.Second, nil) + resp, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hello"}}, nil, "gemini-2.5-pro", nil) + if err != nil { + t.Fatalf("Chat error: %v", err) + } + if resp.Content != "ok" { + t.Fatalf("content = %q, want ok", resp.Content) + } + if got := atomic.LoadInt32(&hits); got != 2 { + t.Fatalf("hits = %d, want 2", got) + } +} + +func TestAntigravityBaseURLsIncludeProdFallback(t *testing.T) { + p := NewAntigravityProvider("antigravity", "", "", "gemini-2.5-pro", false, "oauth", 0, nil) + got := p.baseURLs() + if len(got) < 3 { + t.Fatalf("baseURLs = %#v", got) + } + if got[0] != antigravityDailyBaseURL || got[1] != antigravitySandboxBaseURL || got[2] != antigravityProdBaseURL { + t.Fatalf("unexpected fallback order: %#v", got) + } +} diff --git a/pkg/providers/gemini_cli_provider.go b/pkg/providers/gemini_cli_provider.go new file mode 100644 index 0000000..619af46 --- /dev/null +++ b/pkg/providers/gemini_cli_provider.go @@ -0,0 +1,331 @@ +package providers + +import ( + "bufio" + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "regexp" + "strconv" + "strings" + "time" +) + +const ( + geminiCLIBaseURL = "https://cloudcode-pa.googleapis.com" + geminiCLIVersion = "v1internal" + geminiCLIDefaultAlt = "sse" + geminiCLIApiClient = "genai-cli/0 gl-go/1.0" +) + +type GeminiCLIProvider struct { + base *HTTPProvider +} + +func NewGeminiCLIProvider(providerName, apiKey, apiBase, defaultModel string, supportsResponsesCompact bool, authMode string, timeout time.Duration, oauth *oauthManager) *GeminiCLIProvider { + normalizedBase := normalizeAPIBase(apiBase) + if normalizedBase == "" { + normalizedBase = geminiCLIBaseURL + } + return &GeminiCLIProvider{ + base: NewHTTPProvider(providerName, apiKey, normalizedBase, defaultModel, supportsResponsesCompact, authMode, timeout, oauth), + } +} + +func (p *GeminiCLIProvider) GetDefaultModel() string { + if p == nil || p.base == nil { + return "" + } + return p.base.GetDefaultModel() +} + +func (p *GeminiCLIProvider) Chat(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (*LLMResponse, error) { + body, status, ctype, err := p.doRequest(ctx, messages, tools, model, options, false, nil) + if err != nil { + return nil, err + } + if status != http.StatusOK { + return nil, fmt.Errorf("API error (status %d, content-type %q): %s", status, ctype, previewResponseBody(body)) + } + if !json.Valid(body) { + return nil, fmt.Errorf("API error (status %d, content-type %q): non-JSON response: %s", status, ctype, previewResponseBody(body)) + } + return parseGeminiResponse(body) +} + +func (p *GeminiCLIProvider) ChatStream(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}, onDelta func(string)) (*LLMResponse, error) { + body, status, ctype, err := p.doRequest(ctx, messages, tools, model, options, true, onDelta) + if err != nil { + return nil, err + } + if status != http.StatusOK { + return nil, fmt.Errorf("API error (status %d, content-type %q): %s", status, ctype, previewResponseBody(body)) + } + if !json.Valid(body) { + return nil, fmt.Errorf("API error (status %d, content-type %q): non-JSON response: %s", status, ctype, previewResponseBody(body)) + } + return parseGeminiResponse(body) +} + +func (p *GeminiCLIProvider) CountTokens(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (*UsageInfo, error) { + if p == nil || p.base == nil { + return nil, fmt.Errorf("provider not configured") + } + attempts, err := p.base.authAttempts(ctx) + if err != nil { + return nil, err + } + var lastBody []byte + var lastStatus int + var lastType string + for _, attempt := range attempts { + requestBody := p.buildRequestBody(messages, nil, model, options, false, attempt.session) + delete(requestBody, "project") + delete(requestBody, "model") + request := mapFromAny(requestBody["request"]) + delete(request, "safetySettings") + requestBody["request"] = request + body, status, ctype, reqErr := p.performAttempt(ctx, p.endpoint("countTokens", false), requestBody, attempt, false, nil) + if reqErr != nil { + return nil, reqErr + } + lastBody, lastStatus, lastType = body, status, ctype + reason, retry := classifyOAuthFailure(status, body) + if retry { + applyAttemptFailure(p.base, attempt, reason, geminiRetryAfter(body)) + continue + } + if status != http.StatusOK { + return nil, fmt.Errorf("API error (status %d, content-type %q): %s", status, ctype, previewResponseBody(body)) + } + var payload struct { + TotalTokens int `json:"totalTokens"` + } + if err := json.Unmarshal(body, &payload); err != nil { + return nil, fmt.Errorf("invalid countTokens response: %w", err) + } + p.base.markAttemptSuccess(attempt) + return &UsageInfo{PromptTokens: payload.TotalTokens, TotalTokens: payload.TotalTokens}, nil + } + return nil, fmt.Errorf("API error (status %d, content-type %q): %s", lastStatus, lastType, previewResponseBody(lastBody)) +} + +func (p *GeminiCLIProvider) doRequest(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}, stream bool, onDelta func(string)) ([]byte, int, string, error) { + attempts, err := p.base.authAttempts(ctx) + if err != nil { + return nil, 0, "", err + } + var lastBody []byte + var lastStatus int + var lastType string + action := "generateContent" + if stream { + action = "streamGenerateContent" + } + for _, attempt := range attempts { + requestBody := p.buildRequestBody(messages, tools, model, options, stream, attempt.session) + body, status, ctype, reqErr := p.performAttempt(ctx, p.endpoint(action, stream), requestBody, attempt, stream, onDelta) + if reqErr != nil { + return nil, 0, "", reqErr + } + lastBody, lastStatus, lastType = body, status, ctype + reason, retry := classifyOAuthFailure(status, body) + if retry { + applyAttemptFailure(p.base, attempt, reason, geminiRetryAfter(body)) + continue + } + p.base.markAttemptSuccess(attempt) + return body, status, ctype, nil + } + return lastBody, lastStatus, lastType, nil +} + +func (p *GeminiCLIProvider) buildRequestBody(messages []Message, tools []ToolDefinition, model string, options map[string]interface{}, stream bool, session *oauthSession) map[string]any { + request := map[string]any{ + "request": p.buildInnerRequest(messages, tools, model, options, stream), + "model": strings.TrimSpace(qwenBaseModel(model)), + } + if projectID := geminiCLIProjectID(options, session); projectID != "" { + request["project"] = projectID + } + return request +} + +func (p *GeminiCLIProvider) buildInnerRequest(messages []Message, tools []ToolDefinition, model string, options map[string]interface{}, stream bool) map[string]any { + request := NewGeminiProvider(p.base.providerName, p.base.apiKey, p.base.apiBase, p.base.defaultModel, p.base.supportsResponsesCompact, p.base.authMode, p.base.timeout, p.base.oauth). + buildRequestBody(messages, tools, model, options, stream) + if _, ok := request["safetySettings"]; !ok { + request["safetySettings"] = []map[string]any{} + } + return request +} + +func (p *GeminiCLIProvider) endpoint(action string, stream bool) string { + base := normalizeAPIBase(p.base.apiBase) + if base == "" { + base = geminiCLIBaseURL + } + url := fmt.Sprintf("%s/%s:%s", base, geminiCLIVersion, action) + if stream { + return url + "?alt=" + geminiCLIDefaultAlt + } + return url +} + +func (p *GeminiCLIProvider) performAttempt(ctx context.Context, endpoint string, payload map[string]any, attempt authAttempt, stream bool, onDelta func(string)) ([]byte, int, string, error) { + jsonData, err := json.Marshal(payload) + if err != nil { + return nil, 0, "", fmt.Errorf("failed to marshal request: %w", err) + } + req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(jsonData)) + if err != nil { + return nil, 0, "", fmt.Errorf("failed to create request: %w", err) + } + req.Header.Set("Content-Type", "application/json") + if stream { + req.Header.Set("Accept", "text/event-stream") + } else { + req.Header.Set("Accept", "application/json") + } + if err := applyGeminiCLIAttemptAuth(req, attempt); err != nil { + return nil, 0, "", err + } + applyGeminiCLIHeaders(req, strings.TrimSpace(asString(payload["model"]))) + client, err := p.base.httpClientForAttempt(attempt) + if err != nil { + return nil, 0, "", err + } + resp, err := client.Do(req) + if err != nil { + return nil, 0, "", fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + ctype := strings.TrimSpace(resp.Header.Get("Content-Type")) + if stream && strings.Contains(strings.ToLower(ctype), "text/event-stream") { + return consumeGeminiCLIStream(resp, onDelta) + } + body, readErr := io.ReadAll(resp.Body) + if readErr != nil { + return nil, resp.StatusCode, ctype, fmt.Errorf("failed to read response: %w", readErr) + } + return body, resp.StatusCode, ctype, nil +} + +func applyGeminiCLIAttemptAuth(req *http.Request, attempt authAttempt) error { + if req == nil { + return nil + } + token := strings.TrimSpace(attempt.token) + if attempt.session != nil { + token = firstNonEmpty(strings.TrimSpace(attempt.session.AccessToken), token, asString(attempt.session.Token["access_token"])) + } + if token == "" { + return fmt.Errorf("missing access token for gemini-cli") + } + req.Header.Set("Authorization", "Bearer "+token) + req.Header.Del("x-goog-api-key") + return nil +} + +func consumeGeminiCLIStream(resp *http.Response, onDelta func(string)) ([]byte, int, string, error) { + if onDelta == nil { + onDelta = func(string) {} + } + scanner := bufio.NewScanner(resp.Body) + scanner.Buffer(make([]byte, 0, 64*1024), 2*1024*1024) + var dataLines []string + state := &antigravityStreamState{} + for scanner.Scan() { + line := scanner.Text() + if strings.TrimSpace(line) == "" { + if len(dataLines) > 0 { + payload := strings.Join(dataLines, "\n") + dataLines = dataLines[:0] + if strings.TrimSpace(payload) != "" && strings.TrimSpace(payload) != "[DONE]" { + if delta := state.consume([]byte(payload)); delta != "" { + onDelta(delta) + } + } + } + continue + } + if strings.HasPrefix(line, "data:") { + dataLines = append(dataLines, strings.TrimSpace(strings.TrimPrefix(line, "data:"))) + } + } + if err := scanner.Err(); err != nil { + return nil, resp.StatusCode, strings.TrimSpace(resp.Header.Get("Content-Type")), fmt.Errorf("failed to read stream: %w", err) + } + return state.finalBody(), resp.StatusCode, strings.TrimSpace(resp.Header.Get("Content-Type")), nil +} + +func geminiCLIProjectID(options map[string]interface{}, session *oauthSession) string { + if value, ok := stringOption(options, "gemini_project_id"); ok { + return value + } + if value, ok := stringOption(options, "project_id"); ok { + return value + } + if session == nil { + return "" + } + return firstNonEmpty(strings.TrimSpace(session.ProjectID), asString(session.Token["project_id"]), asString(session.Token["projectId"]), asString(session.Token["project"])) +} + +func applyGeminiCLIHeaders(req *http.Request, model string) { + if req == nil { + return + } + if strings.TrimSpace(model) == "" { + model = "unknown" + } + req.Header.Set("User-Agent", "GeminiCLI/"+model) + req.Header.Set("X-Goog-Api-Client", geminiCLIApiClient) +} + +func geminiRetryAfter(body []byte) *time.Duration { + if len(body) == 0 { + return nil + } + var root map[string]any + if err := json.Unmarshal(body, &root); err != nil { + return retryDelayFromMessage(string(body)) + } + errRoot := mapFromAny(root["error"]) + details, _ := errRoot["details"].([]any) + for _, raw := range details { + detail := mapFromAny(raw) + if asString(detail["@type"]) == "type.googleapis.com/google.rpc.RetryInfo" { + if d, err := time.ParseDuration(strings.TrimSpace(asString(detail["retryDelay"]))); err == nil { + return &d + } + } + } + for _, raw := range details { + detail := mapFromAny(raw) + if asString(detail["@type"]) == "type.googleapis.com/google.rpc.ErrorInfo" { + metadata := mapFromAny(detail["metadata"]) + if d, err := time.ParseDuration(strings.TrimSpace(asString(metadata["quotaResetDelay"]))); err == nil { + return &d + } + } + } + return retryDelayFromMessage(asString(errRoot["message"])) +} + +func retryDelayFromMessage(message string) *time.Duration { + re := regexp.MustCompile(`after\s+(\d+)s\.?`) + matches := re.FindStringSubmatch(strings.TrimSpace(message)) + if len(matches) < 2 { + return nil + } + seconds, err := strconv.Atoi(matches[1]) + if err != nil { + return nil + } + d := time.Duration(seconds) * time.Second + return &d +} diff --git a/pkg/providers/gemini_cli_provider_test.go b/pkg/providers/gemini_cli_provider_test.go new file mode 100644 index 0000000..d7eb50a --- /dev/null +++ b/pkg/providers/gemini_cli_provider_test.go @@ -0,0 +1,119 @@ +package providers + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + "time" +) + +func TestGeminiCLIProviderBuildRequestBodyWrapsEnvelope(t *testing.T) { + p := NewGeminiCLIProvider("gemini-cli", "", "", "gemini-2.5-pro", false, "oauth", 5*time.Second, nil) + body := p.buildRequestBody([]Message{{Role: "user", Content: "hello"}}, nil, "gemini-2.5-pro", nil, false, &oauthSession{ProjectID: "demo-project"}) + + if got := asString(body["project"]); got != "demo-project" { + t.Fatalf("project = %q, want demo-project", got) + } + if got := asString(body["model"]); got != "gemini-2.5-pro" { + t.Fatalf("model = %q, want gemini-2.5-pro", got) + } + request := mapFromAny(body["request"]) + if len(request) == 0 { + t.Fatalf("request envelope missing: %#v", body) + } + if _, ok := request["safetySettings"]; !ok { + t.Fatalf("expected safetySettings in request: %#v", request) + } +} + +func TestGeminiCLIProviderChatUsesCloudCodeEndpoint(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/v1internal:generateContent" { + http.NotFound(w, r) + return + } + if got := r.Header.Get("Authorization"); got != "Bearer token" { + t.Fatalf("authorization = %q", got) + } + if got := r.Header.Get("X-Goog-Api-Client"); got != geminiCLIApiClient { + t.Fatalf("x-goog-api-client = %q", got) + } + if got := r.Header.Get("User-Agent"); got != "GeminiCLI/gemini-2.5-pro" { + t.Fatalf("user-agent = %q", got) + } + var payload map[string]any + if err := json.NewDecoder(r.Body).Decode(&payload); err != nil { + t.Fatalf("decode request: %v", err) + } + if got := asString(payload["project"]); got != "demo-project" { + t.Fatalf("project = %q, want demo-project", got) + } + if got := asString(payload["model"]); got != "gemini-2.5-pro" { + t.Fatalf("model = %q, want gemini-2.5-pro", got) + } + if len(mapFromAny(payload["request"])) == 0 { + t.Fatalf("request envelope missing: %#v", payload) + } + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"candidates":[{"content":{"parts":[{"text":"ok"}]},"finishReason":"STOP"}],"usageMetadata":{"totalTokenCount":2}}`)) + })) + defer server.Close() + + p := NewGeminiCLIProvider("gemini-cli", "token", server.URL, "gemini-2.5-pro", false, "api_key", 5*time.Second, nil) + resp, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hello"}}, nil, "gemini-2.5-pro", map[string]any{"project_id": "demo-project"}) + if err != nil { + t.Fatalf("Chat error: %v", err) + } + if resp.Content != "ok" { + t.Fatalf("content = %q, want ok", resp.Content) + } +} + +func TestGeminiCLIProviderCountTokensRemovesProjectAndModel(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/v1internal:countTokens" { + http.NotFound(w, r) + return + } + var payload map[string]any + if err := json.NewDecoder(r.Body).Decode(&payload); err != nil { + t.Fatalf("decode request: %v", err) + } + if _, ok := payload["project"]; ok { + t.Fatalf("project should be removed for countTokens: %#v", payload) + } + if _, ok := payload["model"]; ok { + t.Fatalf("model should be removed for countTokens: %#v", payload) + } + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"totalTokens":42}`)) + })) + defer server.Close() + + p := NewGeminiCLIProvider("gemini-cli", "token", server.URL, "gemini-2.5-pro", false, "api_key", 5*time.Second, nil) + usage, err := p.CountTokens(t.Context(), []Message{{Role: "user", Content: "hello"}}, nil, "gemini-2.5-pro", map[string]any{"project_id": "demo-project"}) + if err != nil { + t.Fatalf("CountTokens error: %v", err) + } + if usage == nil || usage.TotalTokens != 42 { + t.Fatalf("usage = %#v, want 42", usage) + } +} + +func TestGeminiRetryAfterParsesGoogleRetryInfo(t *testing.T) { + retryAfter := geminiRetryAfter([]byte(`{ + "error": { + "message": "rate limited", + "details": [ + { + "@type": "type.googleapis.com/google.rpc.RetryInfo", + "retryDelay": "1.5s" + } + ] + } + }`)) + if retryAfter == nil || *retryAfter != 1500*time.Millisecond { + t.Fatalf("retryAfter = %#v, want 1.5s", retryAfter) + } +} diff --git a/pkg/providers/gemini_provider.go b/pkg/providers/gemini_provider.go new file mode 100644 index 0000000..5102872 --- /dev/null +++ b/pkg/providers/gemini_provider.go @@ -0,0 +1,677 @@ +package providers + +import ( + "bufio" + "bytes" + "encoding/base64" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "strconv" + "strings" + "time" +) + +const ( + geminiBaseURL = "https://generativelanguage.googleapis.com" + geminiAPIVersion = "v1beta" + geminiImagePreviewModel = "gemini-2.5-flash-image-preview" +) + +var geminiWhitePNGBase64 = base64.StdEncoding.EncodeToString([]byte{ + 0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A, + 0x00, 0x00, 0x00, 0x0D, 0x49, 0x48, 0x44, 0x52, + 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x01, + 0x08, 0x06, 0x00, 0x00, 0x00, 0x1F, 0x15, 0xC4, + 0x89, 0x00, 0x00, 0x00, 0x0D, 0x49, 0x44, 0x41, + 0x54, 0x78, 0x9C, 0x63, 0xF8, 0xFF, 0xFF, 0x3F, + 0x00, 0x05, 0xFE, 0x02, 0xFE, 0xDC, 0xCC, 0x59, + 0xE7, 0x00, 0x00, 0x00, 0x00, 0x49, 0x45, 0x4E, + 0x44, 0xAE, 0x42, 0x60, 0x82, +}) + +type GeminiProvider struct { + base *HTTPProvider +} + +func NewGeminiProvider(providerName, apiKey, apiBase, defaultModel string, supportsResponsesCompact bool, authMode string, timeout time.Duration, oauth *oauthManager) *GeminiProvider { + normalizedBase := normalizeAPIBase(apiBase) + if normalizedBase == "" { + normalizedBase = geminiBaseURL + } + return &GeminiProvider{ + base: NewHTTPProvider(providerName, apiKey, normalizedBase, defaultModel, supportsResponsesCompact, authMode, timeout, oauth), + } +} + +func (p *GeminiProvider) GetDefaultModel() string { + if p == nil || p.base == nil { + return "" + } + return p.base.GetDefaultModel() +} + +func (p *GeminiProvider) Chat(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (*LLMResponse, error) { + body, status, ctype, err := p.doRequest(ctx, messages, tools, model, options, false, nil) + if err != nil { + return nil, err + } + if status != http.StatusOK { + return nil, fmt.Errorf("API error (status %d, content-type %q): %s", status, ctype, previewResponseBody(body)) + } + if !json.Valid(body) { + return nil, fmt.Errorf("API error (status %d, content-type %q): non-JSON response: %s", status, ctype, previewResponseBody(body)) + } + return parseGeminiResponse(body) +} + +func (p *GeminiProvider) ChatStream(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}, onDelta func(string)) (*LLMResponse, error) { + body, status, ctype, err := p.doRequest(ctx, messages, tools, model, options, true, onDelta) + if err != nil { + return nil, err + } + if status != http.StatusOK { + return nil, fmt.Errorf("API error (status %d, content-type %q): %s", status, ctype, previewResponseBody(body)) + } + if !json.Valid(body) { + return nil, fmt.Errorf("API error (status %d, content-type %q): non-JSON response: %s", status, ctype, previewResponseBody(body)) + } + return parseGeminiResponse(body) +} + +func (p *GeminiProvider) CountTokens(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (*UsageInfo, error) { + if p == nil || p.base == nil { + return nil, fmt.Errorf("provider not configured") + } + attempts, err := p.base.authAttempts(ctx) + if err != nil { + return nil, err + } + var lastBody []byte + var lastStatus int + var lastType string + for _, attempt := range attempts { + requestBody := p.buildRequestBody(messages, nil, model, options, false) + delete(requestBody, "tools") + delete(requestBody, "toolConfig") + delete(requestBody, "generationConfig") + endpoint := p.endpoint(attempt, model, "countTokens", false) + body, status, ctype, reqErr := p.performAttempt(ctx, endpoint, requestBody, attempt, false, nil) + if reqErr != nil { + return nil, reqErr + } + lastBody, lastStatus, lastType = body, status, ctype + reason, retry := classifyOAuthFailure(status, body) + if retry { + applyAttemptFailure(p.base, attempt, reason, nil) + continue + } + if status != http.StatusOK { + return nil, fmt.Errorf("API error (status %d, content-type %q): %s", status, ctype, previewResponseBody(body)) + } + var payload struct { + TotalTokens int `json:"totalTokens"` + } + if err := json.Unmarshal(body, &payload); err != nil { + return nil, fmt.Errorf("invalid countTokens response: %w", err) + } + p.base.markAttemptSuccess(attempt) + return &UsageInfo{PromptTokens: payload.TotalTokens, TotalTokens: payload.TotalTokens}, nil + } + return nil, fmt.Errorf("API error (status %d, content-type %q): %s", lastStatus, lastType, previewResponseBody(lastBody)) +} + +func (p *GeminiProvider) doRequest(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}, stream bool, onDelta func(string)) ([]byte, int, string, error) { + if p == nil || p.base == nil { + return nil, 0, "", fmt.Errorf("provider not configured") + } + attempts, err := p.base.authAttempts(ctx) + if err != nil { + return nil, 0, "", err + } + var lastBody []byte + var lastStatus int + var lastType string + for _, attempt := range attempts { + requestBody := p.buildRequestBody(messages, tools, model, options, stream) + endpoint := p.endpoint(attempt, model, "generateContent", stream) + body, status, ctype, reqErr := p.performAttempt(ctx, endpoint, requestBody, attempt, stream, onDelta) + if reqErr != nil { + return nil, 0, "", reqErr + } + lastBody, lastStatus, lastType = body, status, ctype + reason, retry := classifyOAuthFailure(status, body) + if retry { + applyAttemptFailure(p.base, attempt, reason, nil) + continue + } + p.base.markAttemptSuccess(attempt) + return body, status, ctype, nil + } + return lastBody, lastStatus, lastType, nil +} + +func (p *GeminiProvider) performAttempt(ctx context.Context, endpoint string, payload map[string]any, attempt authAttempt, stream bool, onDelta func(string)) ([]byte, int, string, error) { + jsonData, err := json.Marshal(payload) + if err != nil { + return nil, 0, "", fmt.Errorf("failed to marshal request: %w", err) + } + req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(jsonData)) + if err != nil { + return nil, 0, "", fmt.Errorf("failed to create request: %w", err) + } + req.Header.Set("Content-Type", "application/json") + if stream { + req.Header.Set("Accept", "text/event-stream") + } else { + req.Header.Set("Accept", "application/json") + } + applyGeminiAttemptAuth(req, attempt) + client, err := p.base.httpClientForAttempt(attempt) + if err != nil { + return nil, 0, "", err + } + resp, err := client.Do(req) + if err != nil { + return nil, 0, "", fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + ctype := strings.TrimSpace(resp.Header.Get("Content-Type")) + if stream && strings.Contains(strings.ToLower(ctype), "text/event-stream") { + return consumeGeminiStream(resp, onDelta) + } + body, readErr := io.ReadAll(resp.Body) + if readErr != nil { + return nil, resp.StatusCode, ctype, fmt.Errorf("failed to read response: %w", readErr) + } + return body, resp.StatusCode, ctype, nil +} + +func (p *GeminiProvider) endpoint(attempt authAttempt, model, action string, stream bool) string { + base := geminiBaseURLForAttempt(p.base, attempt) + baseModel := strings.TrimSpace(qwenBaseModel(model)) + if stream { + return fmt.Sprintf("%s/%s/models/%s:streamGenerateContent?alt=sse", base, geminiAPIVersion, baseModel) + } + return fmt.Sprintf("%s/%s/models/%s:%s", base, geminiAPIVersion, baseModel, action) +} + +func (p *GeminiProvider) buildRequestBody(messages []Message, tools []ToolDefinition, model string, options map[string]interface{}, stream bool) map[string]any { + request := map[string]any{} + systemParts := make([]map[string]any, 0) + contents := make([]map[string]any, 0, len(messages)) + callNames := map[string]string{} + for _, msg := range messages { + role := strings.ToLower(strings.TrimSpace(msg.Role)) + switch role { + case "system", "developer": + if text := antigravityMessageText(msg); text != "" { + systemParts = append(systemParts, map[string]any{"text": text}) + } + case "user": + if parts := geminiTextParts(msg); len(parts) > 0 { + contents = append(contents, map[string]any{"role": "user", "parts": parts}) + } + case "assistant": + parts := geminiAssistantParts(msg) + for _, tc := range msg.ToolCalls { + name := strings.TrimSpace(tc.Name) + if tc.Function != nil && strings.TrimSpace(tc.Function.Name) != "" { + name = strings.TrimSpace(tc.Function.Name) + } + if name != "" && strings.TrimSpace(tc.ID) != "" { + callNames[strings.TrimSpace(tc.ID)] = name + } + } + if len(parts) > 0 { + contents = append(contents, map[string]any{"role": "model", "parts": parts}) + } + case "tool": + if part := antigravityToolResponsePart(msg, callNames); part != nil { + contents = append(contents, map[string]any{"role": "function", "parts": []map[string]any{part}}) + } + } + } + if len(systemParts) > 0 { + request["systemInstruction"] = map[string]any{"parts": systemParts} + } + if len(contents) > 0 { + request["contents"] = contents + } + if gen := antigravityGenerationConfig(options); len(gen) > 0 { + request["generationConfig"] = gen + } + if extra, ok := mapOption(options, "gemini_generation_config"); ok && len(extra) > 0 { + gen := mapFromAny(request["generationConfig"]) + if gen == nil { + gen = map[string]any{} + } + for k, v := range extra { + gen[k] = v + } + request["generationConfig"] = gen + } + if toolDecls := antigravityToolDeclarations(tools); len(toolDecls) > 0 { + request["tools"] = []map[string]any{{"function_declarations": toolDecls}} + request["toolConfig"] = map[string]any{ + "functionCallingConfig": map[string]any{"mode": "AUTO"}, + } + } + applyGeminiThinkingSuffix(request, model) + return fixGeminiImageAspectRatio(strings.TrimSpace(qwenBaseModel(model)), request) +} + +func applyGeminiThinkingSuffix(request map[string]any, model string) { + suffix := qwenModelSuffix(model) + if strings.TrimSpace(suffix) == "" { + return + } + baseModel := strings.TrimSpace(qwenBaseModel(model)) + gen := mapFromAny(request["generationConfig"]) + if gen == nil { + gen = map[string]any{} + } + thinkingConfig := mapFromAny(gen["thinkingConfig"]) + if thinkingConfig == nil { + thinkingConfig = map[string]any{} + } + delete(thinkingConfig, "thinkingBudget") + delete(thinkingConfig, "thinking_budget") + delete(thinkingConfig, "thinkingLevel") + delete(thinkingConfig, "thinking_level") + delete(thinkingConfig, "include_thoughts") + + lower := strings.ToLower(strings.TrimSpace(suffix)) + switch { + case lower == "auto" || lower == "-1": + thinkingConfig["thinkingBudget"] = -1 + thinkingConfig["includeThoughts"] = true + case lower == "none": + if geminiUsesThinkingLevels(baseModel) { + thinkingConfig["thinkingLevel"] = "low" + } else { + thinkingConfig["thinkingBudget"] = 128 + } + thinkingConfig["includeThoughts"] = false + case isGeminiThinkingLevel(lower): + if geminiUsesThinkingLevels(baseModel) { + thinkingConfig["thinkingLevel"] = normalizeGeminiThinkingLevel(lower) + thinkingConfig["includeThoughts"] = true + } else { + thinkingConfig["thinkingBudget"] = geminiThinkingBudgetForLevel(lower) + thinkingConfig["includeThoughts"] = true + } + default: + if budget, err := strconv.Atoi(lower); err == nil { + if budget < 0 { + thinkingConfig["thinkingBudget"] = -1 + thinkingConfig["includeThoughts"] = true + } else if budget == 0 { + thinkingConfig["thinkingBudget"] = 128 + thinkingConfig["includeThoughts"] = false + } else { + thinkingConfig["thinkingBudget"] = budget + thinkingConfig["includeThoughts"] = true + } + } + } + if len(thinkingConfig) == 0 { + return + } + gen["thinkingConfig"] = thinkingConfig + request["generationConfig"] = gen +} + +func geminiUsesThinkingLevels(model string) bool { + trimmed := strings.ToLower(strings.TrimSpace(model)) + return strings.Contains(trimmed, "gemini-3") +} + +func isGeminiThinkingLevel(value string) bool { + switch strings.ToLower(strings.TrimSpace(value)) { + case "minimal", "low", "medium", "high", "xhigh", "max": + return true + default: + return false + } +} + +func normalizeGeminiThinkingLevel(value string) string { + switch strings.ToLower(strings.TrimSpace(value)) { + case "xhigh", "max": + return "high" + case "minimal": + return "low" + default: + return strings.ToLower(strings.TrimSpace(value)) + } +} + +func geminiThinkingBudgetForLevel(value string) int { + switch strings.ToLower(strings.TrimSpace(value)) { + case "minimal": + return 128 + case "low": + return 1024 + case "medium": + return 8192 + case "high": + return 24576 + case "xhigh", "max": + return 32768 + default: + return 8192 + } +} + +func geminiTextParts(msg Message) []map[string]any { + if len(msg.ContentParts) == 0 { + if text := strings.TrimSpace(msg.Content); text != "" { + return []map[string]any{{"text": text}} + } + return nil + } + parts := make([]map[string]any, 0, len(msg.ContentParts)) + for _, part := range msg.ContentParts { + switch strings.ToLower(strings.TrimSpace(part.Type)) { + case "", "text", "input_text": + if text := strings.TrimSpace(part.Text); text != "" { + parts = append(parts, map[string]any{"text": text}) + } + case "input_image", "image_url": + if inline := geminiInlineDataPart(firstNonEmpty(part.FileData, part.ImageURL), part.MIMEType); inline != nil { + parts = append(parts, inline) + continue + } + if url := strings.TrimSpace(firstNonEmpty(part.ImageURL, part.FileURL)); url != "" { + parts = append(parts, map[string]any{ + "fileData": map[string]any{ + "mimeType": firstNonEmpty(strings.TrimSpace(part.MIMEType), "image/png"), + "fileUri": url, + }, + }) + } + case "input_file", "file": + if inline := geminiInlineDataPart(firstNonEmpty(part.FileData, part.FileURL), part.MIMEType); inline != nil { + parts = append(parts, inline) + continue + } + if url := strings.TrimSpace(part.FileURL); url != "" { + parts = append(parts, map[string]any{ + "fileData": map[string]any{ + "mimeType": firstNonEmpty(strings.TrimSpace(part.MIMEType), "application/octet-stream"), + "fileUri": url, + }, + }) + } + } + } + if len(parts) == 0 && strings.TrimSpace(msg.Content) != "" { + return []map[string]any{{"text": strings.TrimSpace(msg.Content)}} + } + return parts +} + +func geminiAssistantParts(msg Message) []map[string]any { + parts := geminiTextParts(msg) + for _, tc := range msg.ToolCalls { + name := strings.TrimSpace(tc.Name) + args := map[string]any{} + if tc.Function != nil { + if strings.TrimSpace(tc.Function.Name) != "" { + name = strings.TrimSpace(tc.Function.Name) + } + if strings.TrimSpace(tc.Function.Arguments) != "" { + _ = json.Unmarshal([]byte(tc.Function.Arguments), &args) + } + } + if len(args) == 0 && len(tc.Arguments) > 0 { + args = tc.Arguments + } + if name == "" { + continue + } + part := map[string]any{ + "functionCall": map[string]any{ + "name": name, + "args": args, + }, + } + if strings.TrimSpace(tc.ID) != "" { + part["functionCall"].(map[string]any)["id"] = strings.TrimSpace(tc.ID) + } + parts = append(parts, part) + } + return parts +} + +func consumeGeminiStream(resp *http.Response, onDelta func(string)) ([]byte, int, string, error) { + if onDelta == nil { + onDelta = func(string) {} + } + scanner := bufio.NewScanner(resp.Body) + scanner.Buffer(make([]byte, 0, 64*1024), 2*1024*1024) + var dataLines []string + state := &antigravityStreamState{} + for scanner.Scan() { + line := scanner.Text() + if strings.TrimSpace(line) == "" { + if len(dataLines) > 0 { + payload := strings.Join(dataLines, "\n") + dataLines = dataLines[:0] + if strings.TrimSpace(payload) != "" && strings.TrimSpace(payload) != "[DONE]" { + filtered := filterGeminiSSEUsageMetadata([]byte(payload)) + if delta := state.consume(filtered); delta != "" { + onDelta(delta) + } + } + } + continue + } + if strings.HasPrefix(line, "data:") { + dataLines = append(dataLines, strings.TrimSpace(strings.TrimPrefix(line, "data:"))) + } + } + if err := scanner.Err(); err != nil { + return nil, resp.StatusCode, strings.TrimSpace(resp.Header.Get("Content-Type")), fmt.Errorf("failed to read stream: %w", err) + } + return state.finalBody(), resp.StatusCode, strings.TrimSpace(resp.Header.Get("Content-Type")), nil +} + +func parseGeminiResponse(body []byte) (*LLMResponse, error) { + return parseAntigravityResponse(body) +} + +func geminiBaseURLForAttempt(base *HTTPProvider, attempt authAttempt) string { + if attempt.session != nil { + if raw := strings.TrimSpace(attempt.session.ResourceURL); raw != "" { + return normalizeGeminiBaseURL(raw) + } + if attempt.session.Token != nil { + if raw := strings.TrimSpace(asString(attempt.session.Token["base_url"])); raw != "" { + return normalizeGeminiBaseURL(raw) + } + if raw := strings.TrimSpace(asString(attempt.session.Token["resource_url"])); raw != "" { + return normalizeGeminiBaseURL(raw) + } + } + } + if base != nil && strings.TrimSpace(base.apiBase) != "" && !strings.Contains(strings.ToLower(base.apiBase), "api.openai.com") { + return normalizeGeminiBaseURL(base.apiBase) + } + return geminiBaseURL +} + +func normalizeGeminiBaseURL(raw string) string { + trimmed := strings.TrimSpace(raw) + if trimmed == "" { + return geminiBaseURL + } + if !strings.Contains(trimmed, "://") { + trimmed = "https://" + trimmed + } + trimmed = normalizeAPIBase(trimmed) + if strings.Contains(trimmed, "/models/") { + if idx := strings.Index(trimmed, "/models/"); idx > 0 { + trimmed = trimmed[:idx] + } + } + if strings.HasSuffix(trimmed, "/models") { + trimmed = strings.TrimSuffix(trimmed, "/models") + } + if strings.HasSuffix(trimmed, "/"+geminiAPIVersion) { + trimmed = strings.TrimSuffix(trimmed, "/"+geminiAPIVersion) + } + return trimmed +} + +func geminiInlineDataPart(raw, mimeType string) map[string]any { + data, mt, ok := parseDataURL(raw) + if !ok { + return nil + } + return map[string]any{ + "inlineData": map[string]any{ + "mimeType": firstNonEmpty(strings.TrimSpace(mimeType), mt), + "data": data, + }, + } +} + +func parseDataURL(raw string) (data, mimeType string, ok bool) { + trimmed := strings.TrimSpace(raw) + if !strings.HasPrefix(trimmed, "data:") { + return "", "", false + } + comma := strings.Index(trimmed, ",") + if comma <= len("data:") { + return "", "", false + } + meta := trimmed[len("data:"):comma] + payload := trimmed[comma+1:] + if !strings.HasSuffix(strings.ToLower(meta), ";base64") { + return "", "", false + } + mimeType = strings.TrimSuffix(meta, ";base64") + if strings.TrimSpace(mimeType) == "" { + mimeType = "application/octet-stream" + } + return payload, mimeType, true +} + +func fixGeminiImageAspectRatio(modelName string, request map[string]any) map[string]any { + if strings.TrimSpace(modelName) != geminiImagePreviewModel || request == nil { + return request + } + generationConfig := mapFromAny(request["generationConfig"]) + if len(generationConfig) == 0 { + return request + } + imageConfig := mapFromAny(generationConfig["imageConfig"]) + aspectRatio := strings.TrimSpace(asString(imageConfig["aspectRatio"])) + if aspectRatio == "" { + return request + } + contents, _ := request["contents"].([]map[string]any) + hasInlineData := false + for _, content := range contents { + parts, _ := content["parts"].([]map[string]any) + for _, part := range parts { + if len(mapFromAny(part["inlineData"])) > 0 { + hasInlineData = true + break + } + } + if hasInlineData { + break + } + } + if !hasInlineData && len(contents) > 0 { + parts, _ := contents[0]["parts"].([]map[string]any) + prefixed := []map[string]any{ + {"text": "Based on the following requirements, create an image within the uploaded picture. The new content must completely cover the entire area of the original picture, maintaining its exact proportions, and no blank areas should appear."}, + {"inlineData": map[string]any{"mimeType": "image/png", "data": geminiWhitePNGBase64}}, + } + contents[0]["parts"] = append(prefixed, parts...) + request["contents"] = contents + generationConfig["responseModalities"] = []any{"IMAGE", "TEXT"} + } + delete(generationConfig, "imageConfig") + request["generationConfig"] = generationConfig + return request +} + +func filterGeminiSSEUsageMetadata(payload []byte) []byte { + if len(payload) == 0 { + return payload + } + var root map[string]any + if err := json.Unmarshal(bytes.TrimSpace(payload), &root); err != nil { + return payload + } + if geminiPayloadHasFinishReason(root) { + out, err := json.Marshal(root) + if err != nil { + return payload + } + return out + } + delete(root, "usageMetadata") + delete(root, "usage_metadata") + if response := mapFromAny(root["response"]); len(response) > 0 { + delete(response, "usageMetadata") + root["response"] = response + } + out, err := json.Marshal(root) + if err != nil { + return payload + } + return out +} + +func geminiPayloadHasFinishReason(root map[string]any) bool { + if candidateHasFinishReason(root["candidates"]) { + return true + } + if response := mapFromAny(root["response"]); len(response) > 0 { + return candidateHasFinishReason(response["candidates"]) + } + return false +} + +func candidateHasFinishReason(raw any) bool { + switch typed := raw.(type) { + case []any: + if len(typed) == 0 { + return false + } + candidate := mapFromAny(typed[0]) + return strings.TrimSpace(asString(candidate["finishReason"])) != "" + case []map[string]any: + if len(typed) == 0 { + return false + } + return strings.TrimSpace(asString(typed[0]["finishReason"])) != "" + default: + return false + } +} + +func applyGeminiAttemptAuth(req *http.Request, attempt authAttempt) { + if req == nil { + return + } + token := strings.TrimSpace(attempt.token) + if token == "" { + return + } + req.Header.Del("Authorization") + req.Header.Del("x-goog-api-key") + if attempt.kind == "api_key" { + req.Header.Set("x-goog-api-key", token) + return + } + req.Header.Set("Authorization", "Bearer "+token) +} diff --git a/pkg/providers/gemini_provider_test.go b/pkg/providers/gemini_provider_test.go new file mode 100644 index 0000000..5d1d72c --- /dev/null +++ b/pkg/providers/gemini_provider_test.go @@ -0,0 +1,294 @@ +package providers + +import ( + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/YspCoder/clawgo/pkg/config" +) + +func TestGeminiBuildRequestBody(t *testing.T) { + p := NewGeminiProvider("gemini", "", "", "gemini-2.5-pro", false, "oauth", 5*time.Second, nil) + body := p.buildRequestBody([]Message{ + {Role: "system", Content: "You are helpful."}, + {Role: "user", Content: "hello"}, + { + Role: "assistant", + Content: "calling tool", + ToolCalls: []ToolCall{{ + ID: "call_1", + Name: "lookup", + Function: &FunctionCall{ + Name: "lookup", + Arguments: `{"q":"weather"}`, + }, + }}, + }, + {Role: "tool", ToolCallID: "call_1", Content: `{"ok":true}`}, + }, []ToolDefinition{{ + Type: "function", + Function: ToolFunctionDefinition{ + Name: "lookup", + Description: "Lookup data", + Parameters: map[string]interface{}{"type": "object"}, + }, + }}, "gemini-2.5-pro", map[string]interface{}{ + "max_tokens": 128, + "temperature": 0.3, + }, false) + + request := body + if system := asString(mapFromAny(request["systemInstruction"])["parts"].([]map[string]any)[0]["text"]); system != "You are helpful." { + t.Fatalf("expected system instruction, got %q", system) + } + if got := len(request["contents"].([]map[string]any)); got != 3 { + t.Fatalf("expected 3 content entries, got %d", got) + } + gen := mapFromAny(request["generationConfig"]) + if got := intValue(gen["maxOutputTokens"]); got != 128 { + t.Fatalf("expected maxOutputTokens, got %#v", gen["maxOutputTokens"]) + } +} + +func TestGeminiProviderCountTokens(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/v1beta/models/gemini-2.5-pro:countTokens" { + http.NotFound(w, r) + return + } + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"totalTokens":42}`)) + })) + defer server.Close() + + p := NewGeminiProvider("gemini", "token", server.URL, "gemini-2.5-pro", false, "api_key", 5*time.Second, nil) + usage, err := p.CountTokens(t.Context(), []Message{{Role: "user", Content: "hello"}}, nil, "gemini-2.5-pro", nil) + if err != nil { + t.Fatalf("CountTokens error: %v", err) + } + if usage == nil || usage.PromptTokens != 42 || usage.TotalTokens != 42 { + t.Fatalf("usage = %#v, want 42", usage) + } +} + +func TestGeminiProviderChat(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/v1beta/models/gemini-2.5-pro:generateContent" { + http.NotFound(w, r) + return + } + if got := r.Header.Get("x-goog-api-key"); got != "token" { + t.Fatalf("x-goog-api-key = %q", got) + } + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"candidates":[{"content":{"parts":[{"text":"ok"}]},"finishReason":"STOP"}],"usageMetadata":{"promptTokenCount":1,"candidatesTokenCount":1,"totalTokenCount":2}}`)) + })) + defer server.Close() + + p := NewGeminiProvider("gemini", "token", server.URL, "gemini-2.5-pro", false, "api_key", 5*time.Second, nil) + resp, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hello"}}, nil, "gemini-2.5-pro", nil) + if err != nil { + t.Fatalf("Chat error: %v", err) + } + if resp.Content != "ok" { + t.Fatalf("content = %q, want ok", resp.Content) + } + if resp.Usage == nil || resp.Usage.TotalTokens != 2 { + t.Fatalf("usage = %#v", resp.Usage) + } +} + +func TestGeminiBuildRequestBodyFixesImageAspectRatio(t *testing.T) { + p := NewGeminiProvider("gemini", "", "", geminiImagePreviewModel, false, "api_key", 5*time.Second, nil) + body := p.buildRequestBody([]Message{ + {Role: "user", Content: "draw a cat"}, + }, nil, geminiImagePreviewModel, map[string]interface{}{ + "gemini_generation_config": map[string]interface{}{ + "imageConfig": map[string]interface{}{"aspectRatio": "1:1"}, + }, + }, false) + + gen := mapFromAny(body["generationConfig"]) + if _, ok := gen["imageConfig"]; ok { + t.Fatalf("expected imageConfig to be removed, got %#v", gen["imageConfig"]) + } + if got := len(gen["responseModalities"].([]any)); got != 2 { + t.Fatalf("responseModalities len = %d", got) + } + contents := body["contents"].([]map[string]any) + parts := contents[0]["parts"].([]map[string]any) + if len(parts) < 3 { + t.Fatalf("parts = %#v", parts) + } + if _, ok := mapFromAny(parts[1]["inlineData"])["data"]; !ok { + t.Fatalf("expected inlineData placeholder, got %#v", parts[1]) + } +} + +func TestGeminiBuildRequestBodyAppliesBudgetThinkingSuffix(t *testing.T) { + p := NewGeminiProvider("gemini", "", "", "gemini-2.5-pro", false, "api_key", 5*time.Second, nil) + + body := p.buildRequestBody([]Message{{Role: "user", Content: "hi"}}, nil, "gemini-2.5-pro(high)", nil, false) + gen := mapFromAny(body["generationConfig"]) + thinking := mapFromAny(gen["thinkingConfig"]) + if got := intValue(thinking["thinkingBudget"]); got != 24576 { + t.Fatalf("thinkingBudget = %v, want 24576", thinking["thinkingBudget"]) + } + if got := fmt.Sprintf("%v", thinking["includeThoughts"]); got != "true" { + t.Fatalf("includeThoughts = %v, want true", thinking["includeThoughts"]) + } + + body = p.buildRequestBody([]Message{{Role: "user", Content: "hi"}}, nil, "gemini-2.5-pro(none)", nil, false) + gen = mapFromAny(body["generationConfig"]) + thinking = mapFromAny(gen["thinkingConfig"]) + if got := intValue(thinking["thinkingBudget"]); got != 128 { + t.Fatalf("thinkingBudget = %v, want 128", thinking["thinkingBudget"]) + } + if got := fmt.Sprintf("%v", thinking["includeThoughts"]); got != "false" { + t.Fatalf("includeThoughts = %v, want false", thinking["includeThoughts"]) + } + + body = p.buildRequestBody([]Message{{Role: "user", Content: "hi"}}, nil, "gemini-2.5-pro(auto)", nil, false) + gen = mapFromAny(body["generationConfig"]) + thinking = mapFromAny(gen["thinkingConfig"]) + if got := intValue(thinking["thinkingBudget"]); got != -1 { + t.Fatalf("thinkingBudget = %v, want -1", thinking["thinkingBudget"]) + } + + body = p.buildRequestBody([]Message{{Role: "user", Content: "hi"}}, nil, "gemini-2.5-pro(8192)", nil, false) + gen = mapFromAny(body["generationConfig"]) + thinking = mapFromAny(gen["thinkingConfig"]) + if got := intValue(thinking["thinkingBudget"]); got != 8192 { + t.Fatalf("thinkingBudget = %v, want 8192", thinking["thinkingBudget"]) + } +} + +func TestGeminiBuildRequestBodyAppliesLevelThinkingSuffix(t *testing.T) { + p := NewGeminiProvider("gemini", "", "", "gemini-3-pro-preview", false, "api_key", 5*time.Second, nil) + + body := p.buildRequestBody([]Message{{Role: "user", Content: "hi"}}, nil, "gemini-3-pro-preview(high)", nil, false) + gen := mapFromAny(body["generationConfig"]) + thinking := mapFromAny(gen["thinkingConfig"]) + if got := asString(thinking["thinkingLevel"]); got != "high" { + t.Fatalf("thinkingLevel = %q, want high", got) + } + if got := fmt.Sprintf("%v", thinking["includeThoughts"]); got != "true" { + t.Fatalf("includeThoughts = %v, want true", thinking["includeThoughts"]) + } + + body = p.buildRequestBody([]Message{{Role: "user", Content: "hi"}}, nil, "gemini-3-pro-preview(xhigh)", nil, false) + gen = mapFromAny(body["generationConfig"]) + thinking = mapFromAny(gen["thinkingConfig"]) + if got := asString(thinking["thinkingLevel"]); got != "high" { + t.Fatalf("thinkingLevel = %q, want high", got) + } + + body = p.buildRequestBody([]Message{{Role: "user", Content: "hi"}}, nil, "gemini-3-pro-preview(none)", nil, false) + gen = mapFromAny(body["generationConfig"]) + thinking = mapFromAny(gen["thinkingConfig"]) + if got := asString(thinking["thinkingLevel"]); got != "low" { + t.Fatalf("thinkingLevel = %q, want low", got) + } + if got := fmt.Sprintf("%v", thinking["includeThoughts"]); got != "false" { + t.Fatalf("includeThoughts = %v, want false", thinking["includeThoughts"]) + } + + body = p.buildRequestBody([]Message{{Role: "user", Content: "hi"}}, nil, "gemini-3-pro-preview(auto)", nil, false) + gen = mapFromAny(body["generationConfig"]) + thinking = mapFromAny(gen["thinkingConfig"]) + if got := intValue(thinking["thinkingBudget"]); got != -1 { + t.Fatalf("thinkingBudget = %v, want -1", thinking["thinkingBudget"]) + } +} + +func TestFilterGeminiSSEUsageMetadataDropsNonTerminalUsage(t *testing.T) { + raw := []byte(`{"candidates":[{"content":{"parts":[{"text":"hello"}]}}],"usageMetadata":{"promptTokenCount":1}}`) + filtered := filterGeminiSSEUsageMetadata(raw) + var payload map[string]any + if err := json.Unmarshal(filtered, &payload); err != nil { + t.Fatalf("unmarshal filtered: %v", err) + } + if _, ok := payload["usageMetadata"]; ok { + t.Fatalf("expected usageMetadata to be removed: %#v", payload) + } +} + +func TestGeminiTextPartsSupportInlineDataAndFileURLs(t *testing.T) { + parts := geminiTextParts(Message{ + Role: "user", + ContentParts: []MessageContentPart{ + {Type: "input_image", ImageURL: "data:image/png;base64,AAAA", MIMEType: "image/png"}, + {Type: "input_file", FileURL: "https://example.com/doc.pdf", MIMEType: "application/pdf"}, + }, + }) + if len(parts) != 2 { + t.Fatalf("parts = %#v", parts) + } + inline := mapFromAny(parts[0]["inlineData"]) + if got := asString(inline["data"]); got != "AAAA" { + t.Fatalf("inline data = %q", got) + } + fileData := mapFromAny(parts[1]["fileData"]) + if got := asString(fileData["fileUri"]); got != "https://example.com/doc.pdf" { + t.Fatalf("file uri = %q", got) + } +} + +func TestGeminiBaseURLForAttemptUsesSessionResourceURL(t *testing.T) { + base := NewHTTPProvider("gemini", "token", geminiBaseURL, "gemini-2.5-pro", false, "oauth", 5*time.Second, nil) + got := geminiBaseURLForAttempt(base, authAttempt{ + kind: "oauth", + session: &oauthSession{ + ResourceURL: "https://generativelanguage.googleapis.com/v1beta/models", + }, + }) + if got != geminiBaseURL { + t.Fatalf("base url = %q, want %q", got, geminiBaseURL) + } +} + +func TestCreateProviderByNameRoutesGeminiCLIToGeminiProvider(t *testing.T) { + cfg := &config.Config{ + Models: config.ModelsConfig{ + Providers: map[string]config.ProviderConfig{ + "gemini-cli": { + APIBase: "", + APIKey: "token", + TimeoutSec: 30, + Models: []string{"gemini-2.5-pro"}, + }, + }, + }, + } + provider, err := CreateProviderByName(cfg, "gemini-cli") + if err != nil { + t.Fatalf("CreateProviderByName error: %v", err) + } + if _, ok := provider.(*GeminiCLIProvider); !ok { + t.Fatalf("provider = %T, want *GeminiCLIProvider", provider) + } +} + +func TestCreateProviderByNameRoutesAIStudioProviderViaGeminiTests(t *testing.T) { + cfg := &config.Config{ + Models: config.ModelsConfig{ + Providers: map[string]config.ProviderConfig{ + "aistudio": { + TimeoutSec: 30, + Models: []string{"gemini-2.5-pro"}, + }, + }, + }, + } + provider, err := CreateProviderByName(cfg, "aistudio") + if err != nil { + t.Fatalf("CreateProviderByName error: %v", err) + } + if _, ok := provider.(*AistudioProvider); !ok { + t.Fatalf("provider = %T, want *AistudioProvider", provider) + } +} diff --git a/pkg/providers/http_provider.go b/pkg/providers/http_provider.go index 4140e76..5b0c803 100644 --- a/pkg/providers/http_provider.go +++ b/pkg/providers/http_provider.go @@ -943,11 +943,11 @@ func applyAttemptProviderHeaders(req *http.Request, attempt authAttempt, provide req.Header.Set("X-Stainless-Runtime-Version", "v22.17.0") req.Header.Set("Sec-Fetch-Mode", "cors") req.Header.Set("X-Stainless-Lang", "js") - req.Header.Set("X-Stainless-Arch", "arm64") + req.Header.Set("X-Stainless-Arch", qwenStainlessArch()) req.Header.Set("X-Stainless-Package-Version", "5.11.0") req.Header.Set("X-Dashscope-Cachecontrol", "enable") req.Header.Set("X-Stainless-Retry-Count", "0") - req.Header.Set("X-Stainless-Os", "MacOS") + req.Header.Set("X-Stainless-Os", qwenStainlessOS()) req.Header.Set("X-Dashscope-Authtype", "qwen-oauth") req.Header.Set("X-Stainless-Runtime", "node") if stream { @@ -962,13 +962,9 @@ func applyAttemptProviderHeaders(req *http.Request, attempt authAttempt, provide req.Header.Set("User-Agent", kimiCompatUserAgent) req.Header.Set("X-Msh-Platform", "kimi_cli") req.Header.Set("X-Msh-Version", "1.10.6") - req.Header.Set("X-Msh-Device-Name", "clawgo") - req.Header.Set("X-Msh-Device-Model", runtime.GOOS+" "+runtime.GOARCH) - if attempt.session != nil && strings.TrimSpace(attempt.session.DeviceID) != "" { - req.Header.Set("X-Msh-Device-Id", strings.TrimSpace(attempt.session.DeviceID)) - } else { - req.Header.Set("X-Msh-Device-Id", "clawgo-device") - } + req.Header.Set("X-Msh-Device-Name", kimiDeviceName()) + req.Header.Set("X-Msh-Device-Model", kimiDeviceModel()) + req.Header.Set("X-Msh-Device-Id", kimiDeviceID(attempt.session)) if stream { req.Header.Set("Accept", "text/event-stream") } else { @@ -1005,9 +1001,42 @@ func randomSessionID() string { return fmt.Sprintf("%x-%x-%x-%x-%x", buf[0:4], buf[4:6], buf[6:8], buf[8:10], buf[10:16]) } +func qwenStainlessArch() string { + switch runtime.GOARCH { + case "amd64": + return "x64" + case "386": + return "x86" + default: + return runtime.GOARCH + } +} + +func qwenStainlessOS() string { + switch runtime.GOOS { + case "darwin": + return "MacOS" + case "windows": + return "Windows" + case "linux": + return "Linux" + default: + if runtime.GOOS == "" { + return "" + } + return strings.ToUpper(runtime.GOOS[:1]) + runtime.GOOS[1:] + } +} + func (p *HTTPProvider) httpClientForAttempt(attempt authAttempt) (*http.Client, error) { if attempt.kind == "oauth" && attempt.session != nil && p.oauth != nil { - return p.oauth.httpClientForSession(attempt.session) + client, err := p.oauth.httpClientForSession(attempt.session) + if err != nil { + return nil, err + } + if client != nil { + return client, nil + } } return p.httpClient, nil } @@ -1733,7 +1762,11 @@ func GetProviderRuntimeSnapshot(cfg *config.Config) map[string]interface{} { "last_success": state.LastSuccess, } candidateOrder := state.CandidateOrder - if strings.EqualFold(strings.TrimSpace(pc.Auth), "oauth") || strings.EqualFold(strings.TrimSpace(pc.Auth), "hybrid") { + if strings.EqualFold(name, "aistudio") { + if accounts := listAIStudioRelayAccounts(); len(accounts) > 0 { + item["oauth_accounts"] = accounts + } + } else if strings.EqualFold(strings.TrimSpace(pc.Auth), "oauth") || strings.EqualFold(strings.TrimSpace(pc.Auth), "hybrid") { if mgr, err := NewOAuthLoginManager(pc, time.Duration(maxInt(pc.TimeoutSec, 90))*time.Second); err == nil { if accounts, err := mgr.ListAccounts(); err == nil { item["oauth_accounts"] = accounts @@ -2343,13 +2376,14 @@ func openAICompatMessages(messages []Message) []map[string]interface{} { out := make([]map[string]interface{}, 0, len(messages)) for _, msg := range messages { role := strings.ToLower(strings.TrimSpace(msg.Role)) + content := openAICompatMessageContent(msg) switch role { case "system": - out = append(out, map[string]interface{}{"role": "system", "content": msg.Content}) + out = append(out, map[string]interface{}{"role": "system", "content": content}) case "developer": - out = append(out, map[string]interface{}{"role": "user", "content": msg.Content}) + out = append(out, map[string]interface{}{"role": "user", "content": content}) case "assistant": - item := map[string]interface{}{"role": "assistant", "content": msg.Content} + item := map[string]interface{}{"role": "assistant", "content": content} if len(msg.ToolCalls) > 0 { toolCalls := make([]map[string]interface{}, 0, len(msg.ToolCalls)) for _, tc := range msg.ToolCalls { @@ -2381,15 +2415,66 @@ func openAICompatMessages(messages []Message) []map[string]interface{} { out = append(out, map[string]interface{}{ "role": "tool", "tool_call_id": msg.ToolCallID, - "content": msg.Content, + "content": content, }) default: - out = append(out, map[string]interface{}{"role": "user", "content": msg.Content}) + out = append(out, map[string]interface{}{"role": "user", "content": content}) } } return out } +func openAICompatMessageContent(msg Message) interface{} { + if len(msg.ContentParts) == 0 { + return msg.Content + } + parts := make([]map[string]interface{}, 0, len(msg.ContentParts)) + for _, part := range msg.ContentParts { + switch strings.ToLower(strings.TrimSpace(part.Type)) { + case "text", "input_text": + if strings.TrimSpace(part.Text) == "" { + continue + } + parts = append(parts, map[string]interface{}{ + "type": "text", + "text": part.Text, + }) + case "input_image", "image_url": + imageURL := strings.TrimSpace(part.ImageURL) + if imageURL == "" { + continue + } + payload := map[string]interface{}{ + "type": "image_url", + "image_url": map[string]interface{}{ + "url": imageURL, + }, + } + if detail := strings.TrimSpace(part.Detail); detail != "" { + payload["image_url"].(map[string]interface{})["detail"] = detail + } + parts = append(parts, payload) + default: + if strings.TrimSpace(part.Text) == "" { + continue + } + parts = append(parts, map[string]interface{}{ + "type": "text", + "text": part.Text, + }) + } + } + if len(parts) == 0 { + return msg.Content + } + if len(parts) == 1 && parts[0]["type"] == "text" && len(msg.ToolCalls) == 0 { + if text, _ := parts[0]["text"].(string); text != "" { + return text + } + } + return parts +} + func openAICompatTools(tools []ToolDefinition) []map[string]interface{} { out := make([]map[string]interface{}, 0, len(tools)) for _, tool := range tools { @@ -2615,7 +2700,12 @@ func CreateProviderByName(cfg *config.Config, name string) (LLMProvider, error) } ConfigureProviderRuntime(name, pc) oauthProvider := strings.ToLower(strings.TrimSpace(pc.OAuth.Provider)) - if pc.APIBase == "" && oauthProvider != defaultAntigravityOAuthProvider { + if pc.APIBase == "" && + oauthProvider != defaultAntigravityOAuthProvider && + oauthProvider != defaultGeminiOAuthProvider && + !strings.EqualFold(name, "gemini-cli") && + !strings.EqualFold(name, "aistudio") && + !strings.EqualFold(name, "vertex") { return nil, fmt.Errorf("no API base configured for provider %q", name) } if pc.TimeoutSec <= 0 { @@ -2635,6 +2725,18 @@ func CreateProviderByName(cfg *config.Config, name string) (LLMProvider, error) if oauthProvider == defaultAntigravityOAuthProvider { return NewAntigravityProvider(name, pc.APIKey, pc.APIBase, defaultModel, pc.SupportsResponsesCompact, pc.Auth, time.Duration(pc.TimeoutSec)*time.Second, oauth), nil } + if strings.EqualFold(name, "aistudio") { + return NewAistudioProvider(name, pc.APIKey, pc.APIBase, defaultModel, pc.SupportsResponsesCompact, pc.Auth, time.Duration(pc.TimeoutSec)*time.Second, oauth), nil + } + if strings.EqualFold(name, "gemini-cli") { + return NewGeminiCLIProvider(name, pc.APIKey, pc.APIBase, defaultModel, pc.SupportsResponsesCompact, pc.Auth, time.Duration(pc.TimeoutSec)*time.Second, oauth), nil + } + if oauthProvider == defaultGeminiOAuthProvider || strings.EqualFold(name, defaultGeminiOAuthProvider) || strings.EqualFold(name, "aistudio") { + return NewGeminiProvider(name, pc.APIKey, pc.APIBase, defaultModel, pc.SupportsResponsesCompact, pc.Auth, time.Duration(pc.TimeoutSec)*time.Second, oauth), nil + } + if strings.EqualFold(name, "vertex") { + return NewVertexProvider(name, pc.APIKey, pc.APIBase, defaultModel, pc.SupportsResponsesCompact, pc.Auth, time.Duration(pc.TimeoutSec)*time.Second, oauth), nil + } if oauthProvider == defaultCodexOAuthProvider { return NewCodexProvider(name, pc.APIKey, pc.APIBase, defaultModel, pc.SupportsResponsesCompact, pc.Auth, time.Duration(pc.TimeoutSec)*time.Second, oauth), nil } @@ -2647,6 +2749,9 @@ func CreateProviderByName(cfg *config.Config, name string) (LLMProvider, error) if oauthProvider == defaultKimiOAuthProvider { return NewKimiProvider(name, pc.APIKey, pc.APIBase, defaultModel, pc.SupportsResponsesCompact, pc.Auth, time.Duration(pc.TimeoutSec)*time.Second, oauth), nil } + if oauthProvider == defaultIFlowOAuthProvider || strings.EqualFold(name, defaultIFlowOAuthProvider) { + return NewIFlowProvider(name, pc.APIKey, pc.APIBase, defaultModel, pc.SupportsResponsesCompact, pc.Auth, time.Duration(pc.TimeoutSec)*time.Second, oauth), nil + } return NewHTTPProvider(name, pc.APIKey, pc.APIBase, defaultModel, pc.SupportsResponsesCompact, pc.Auth, time.Duration(pc.TimeoutSec)*time.Second, oauth), nil } diff --git a/pkg/providers/iflow_provider.go b/pkg/providers/iflow_provider.go new file mode 100644 index 0000000..8d3f77f --- /dev/null +++ b/pkg/providers/iflow_provider.go @@ -0,0 +1,352 @@ +package providers + +import ( + "bytes" + "context" + "crypto/hmac" + "crypto/sha256" + "encoding/hex" + "encoding/json" + "fmt" + "net/http" + "strings" + "time" + + "github.com/google/uuid" +) + +const ( + iflowCompatBaseURL = "https://apis.iflow.cn/v1" + iflowCompatEndpoint = "/chat/completions" + iflowCompatUserAgent = "iFlow-Cli" +) + +type IFlowProvider struct { + base *HTTPProvider +} + +func NewIFlowProvider(providerName, apiKey, apiBase, defaultModel string, supportsResponsesCompact bool, authMode string, timeout time.Duration, oauth *oauthManager) *IFlowProvider { + return &IFlowProvider{base: NewHTTPProvider(providerName, apiKey, apiBase, defaultModel, supportsResponsesCompact, authMode, timeout, oauth)} +} + +func (p *IFlowProvider) GetDefaultModel() string { return openAICompatDefaultModel(p.base) } + +func (p *IFlowProvider) Chat(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (*LLMResponse, error) { + if p == nil || p.base == nil { + return nil, fmt.Errorf("provider not configured") + } + body := buildIFlowChatRequest(p.base, messages, tools, model, options, false) + respBody, statusCode, contentType, err := doIFlowJSONWithAttempts(ctx, p.base, body) + if err != nil { + return nil, err + } + if statusCode != http.StatusOK { + return nil, fmt.Errorf("API error (status %d, content-type %q): %s", statusCode, contentType, previewResponseBody(respBody)) + } + if !json.Valid(respBody) { + return nil, fmt.Errorf("API error (status %d, content-type %q): non-JSON response: %s", statusCode, contentType, previewResponseBody(respBody)) + } + return parseOpenAICompatResponse(respBody) +} + +func (p *IFlowProvider) ChatStream(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}, onDelta func(string)) (*LLMResponse, error) { + if p == nil || p.base == nil { + return nil, fmt.Errorf("provider not configured") + } + if onDelta == nil { + onDelta = func(string) {} + } + body := buildIFlowChatRequest(p.base, messages, tools, model, options, true) + respBody, statusCode, contentType, err := doIFlowStreamWithAttempts(ctx, p.base, body, onDelta) + if err != nil { + return nil, err + } + if statusCode != http.StatusOK { + return nil, fmt.Errorf("API error (status %d, content-type %q): %s", statusCode, contentType, previewResponseBody(respBody)) + } + if !json.Valid(respBody) { + return nil, fmt.Errorf("API error (status %d, content-type %q): non-JSON response: %s", statusCode, contentType, previewResponseBody(respBody)) + } + return parseOpenAICompatResponse(respBody) +} + +func (p *IFlowProvider) CountTokens(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (*UsageInfo, error) { + if p == nil || p.base == nil { + return nil, fmt.Errorf("provider not configured") + } + body := buildIFlowChatRequest(p.base, messages, tools, model, options, false) + count, err := estimateOpenAICompatTokenCount(body) + if err != nil { + return nil, err + } + return &UsageInfo{ + PromptTokens: count, + TotalTokens: count, + }, nil +} + +func buildIFlowChatRequest(base *HTTPProvider, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}, stream bool) map[string]interface{} { + baseModel := qwenBaseModel(model) + body := base.buildOpenAICompatChatRequest(messages, tools, baseModel, options) + if stream { + body["stream"] = true + body["stream_options"] = map[string]interface{}{"include_usage": true} + iflowEnsureToolsArray(body) + } + applyIFlowThinking(body, model) + return body +} + +func applyIFlowThinking(body map[string]interface{}, model string) { + enabled, ok := iflowThinkingEnabled(model, body) + if !ok { + return + } + lowerModel := strings.ToLower(strings.TrimSpace(fmt.Sprintf("%v", body["model"]))) + if strings.HasPrefix(lowerModel, "minimax") { + body["reasoning_split"] = enabled + return + } + kwargs, _ := body["chat_template_kwargs"].(map[string]interface{}) + if kwargs == nil { + kwargs = map[string]interface{}{} + } + kwargs["enable_thinking"] = enabled + delete(kwargs, "clear_thinking") + if enabled && strings.HasPrefix(lowerModel, "glm") { + kwargs["clear_thinking"] = false + } + body["chat_template_kwargs"] = kwargs +} + +func iflowThinkingEnabled(model string, body map[string]interface{}) (bool, bool) { + if suffix := strings.ToLower(strings.TrimSpace(qwenModelSuffix(model))); suffix != "" { + switch suffix { + case "none": + return false, true + default: + return true, true + } + } + if effort := strings.ToLower(strings.TrimSpace(fmt.Sprintf("%v", body["reasoning_effort"]))); effort != "" { + return effort != "none", true + } + if thinking, ok := body["thinking"].(map[string]interface{}); ok { + typ := strings.ToLower(strings.TrimSpace(fmt.Sprintf("%v", thinking["type"]))) + if typ == "disabled" { + return false, true + } + if budget, ok := thinking["budget_tokens"]; ok { + return intValue(budget) > 0, true + } + if typ != "" { + return true, true + } + } + return false, false +} + +func iflowEnsureToolsArray(body map[string]interface{}) { + if _, exists := body["tools"]; !exists { + body["tools"] = []map[string]interface{}{} + } + switch tools := body["tools"].(type) { + case []map[string]interface{}: + if len(tools) > 0 { + return + } + case []interface{}: + if len(tools) > 0 { + return + } + default: + return + } + body["tools"] = []map[string]interface{}{ + { + "type": "function", + "function": map[string]interface{}{ + "name": "noop", + "description": "Placeholder tool to stabilise streaming", + "parameters": map[string]interface{}{ + "type": "object", + }, + }, + }, + } +} + +func doIFlowJSONWithAttempts(ctx context.Context, base *HTTPProvider, payload map[string]interface{}) ([]byte, int, string, error) { + if base == nil { + return nil, 0, "", fmt.Errorf("provider not configured") + } + jsonData, err := json.Marshal(payload) + if err != nil { + return nil, 0, "", fmt.Errorf("failed to marshal request: %w", err) + } + attempts, err := base.authAttempts(ctx) + if err != nil { + return nil, 0, "", err + } + var lastBody []byte + var lastStatus int + var lastType string + for _, attempt := range attempts { + req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpointFor(iflowBaseURLForAttempt(base, attempt), iflowCompatEndpoint), bytes.NewReader(jsonData)) + if err != nil { + return nil, 0, "", fmt.Errorf("failed to create request: %w", err) + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept", "application/json") + applyAttemptAuth(req, attempt) + applyIFlowHeaders(req, iflowAttemptAPIKey(attempt), false) + + body, status, contentType, err := base.doJSONAttempt(req, attempt) + if err != nil { + return nil, 0, "", err + } + reason, retry := classifyOAuthFailure(status, body) + if !retry { + base.markAttemptSuccess(attempt) + return body, status, contentType, nil + } + lastBody, lastStatus, lastType = body, status, contentType + applyAttemptFailure(base, attempt, reason, nil) + } + return lastBody, lastStatus, lastType, nil +} + +func doIFlowStreamWithAttempts(ctx context.Context, base *HTTPProvider, payload map[string]interface{}, onDelta func(string)) ([]byte, int, string, error) { + if base == nil { + return nil, 0, "", fmt.Errorf("provider not configured") + } + jsonData, err := json.Marshal(payload) + if err != nil { + return nil, 0, "", fmt.Errorf("failed to marshal request: %w", err) + } + attempts, err := base.authAttempts(ctx) + if err != nil { + return nil, 0, "", err + } + var lastBody []byte + var lastStatus int + var lastType string + for _, attempt := range attempts { + req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpointFor(iflowBaseURLForAttempt(base, attempt), iflowCompatEndpoint), bytes.NewReader(jsonData)) + if err != nil { + return nil, 0, "", fmt.Errorf("failed to create request: %w", err) + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept", "text/event-stream") + applyAttemptAuth(req, attempt) + applyIFlowHeaders(req, iflowAttemptAPIKey(attempt), true) + + body, status, contentType, quotaHit, err := base.doStreamAttempt(req, attempt, func(event string) { + var obj map[string]interface{} + if err := json.Unmarshal([]byte(event), &obj); err != nil { + return + } + choices, _ := obj["choices"].([]interface{}) + for _, choice := range choices { + item, _ := choice.(map[string]interface{}) + delta, _ := item["delta"].(map[string]interface{}) + if txt := strings.TrimSpace(fmt.Sprintf("%v", delta["content"])); txt != "" { + onDelta(txt) + } + } + }) + if err != nil { + return nil, 0, "", err + } + if !quotaHit { + base.markAttemptSuccess(attempt) + return body, status, contentType, nil + } + lastBody, lastStatus, lastType = body, status, contentType + reason, _ := classifyOAuthFailure(status, body) + applyAttemptFailure(base, attempt, reason, nil) + } + return lastBody, lastStatus, lastType, nil +} + +func iflowBaseURLForAttempt(base *HTTPProvider, attempt authAttempt) string { + if attempt.session != nil { + if raw := strings.TrimSpace(attempt.session.ResourceURL); raw != "" { + return normalizeIFlowBaseURL(raw) + } + if attempt.session.Token != nil { + if raw := strings.TrimSpace(asString(attempt.session.Token["base_url"])); raw != "" { + return normalizeIFlowBaseURL(raw) + } + if raw := strings.TrimSpace(asString(attempt.session.Token["resource_url"])); raw != "" { + return normalizeIFlowBaseURL(raw) + } + } + } + if base != nil && strings.TrimSpace(base.apiBase) != "" && !strings.Contains(strings.ToLower(base.apiBase), "api.openai.com") { + return normalizeIFlowBaseURL(base.apiBase) + } + return iflowCompatBaseURL +} + +func normalizeIFlowBaseURL(raw string) string { + trimmed := strings.TrimSpace(raw) + if trimmed == "" { + return iflowCompatBaseURL + } + if !strings.Contains(trimmed, "://") { + trimmed = "https://" + trimmed + } + trimmed = normalizeAPIBase(trimmed) + if strings.HasSuffix(strings.ToLower(trimmed), "/chat/completions") { + trimmed = strings.TrimSuffix(trimmed, "/chat/completions") + } + if !strings.HasSuffix(strings.ToLower(trimmed), "/v1") { + trimmed = strings.TrimRight(trimmed, "/") + "/v1" + } + return trimmed +} + +func iflowAttemptAPIKey(attempt authAttempt) string { + if attempt.session != nil && attempt.session.Token != nil { + if v := strings.TrimSpace(asString(attempt.session.Token["api_key"])); v != "" { + return v + } + if v := strings.TrimSpace(asString(attempt.session.Token["apiKey"])); v != "" { + return v + } + } + return strings.TrimSpace(attempt.token) +} + +func applyIFlowHeaders(req *http.Request, apiKey string, stream bool) { + if req == nil { + return + } + if strings.TrimSpace(apiKey) != "" { + req.Header.Set("Authorization", "Bearer "+strings.TrimSpace(apiKey)) + } + req.Header.Set("User-Agent", iflowCompatUserAgent) + sessionID := "session-" + uuid.New().String() + req.Header.Set("session-id", sessionID) + timestamp := time.Now().UnixMilli() + req.Header.Set("x-iflow-timestamp", fmt.Sprintf("%d", timestamp)) + if sig := createIFlowSignature(iflowCompatUserAgent, sessionID, timestamp, apiKey); sig != "" { + req.Header.Set("x-iflow-signature", sig) + } + if stream { + req.Header.Set("Accept", "text/event-stream") + } else { + req.Header.Set("Accept", "application/json") + } +} + +func createIFlowSignature(userAgent, sessionID string, timestamp int64, apiKey string) string { + if strings.TrimSpace(apiKey) == "" { + return "" + } + payload := fmt.Sprintf("%s:%s:%d", userAgent, sessionID, timestamp) + h := hmac.New(sha256.New, []byte(apiKey)) + _, _ = h.Write([]byte(payload)) + return hex.EncodeToString(h.Sum(nil)) +} diff --git a/pkg/providers/iflow_provider_test.go b/pkg/providers/iflow_provider_test.go new file mode 100644 index 0000000..bc61798 --- /dev/null +++ b/pkg/providers/iflow_provider_test.go @@ -0,0 +1,130 @@ +package providers + +import ( + "net/http" + "strings" + "testing" + "time" +) + +func TestBuildIFlowChatRequestAppliesGLMThinking(t *testing.T) { + base := NewHTTPProvider("iflow", "token", iflowCompatBaseURL, "glm-4.6", false, "api_key", 5*time.Second, nil) + body := buildIFlowChatRequest(base, []Message{{Role: "user", Content: "hi"}}, nil, "glm-4.6(high)", nil, false) + + if got := body["model"]; got != "glm-4.6" { + t.Fatalf("model = %#v, want glm-4.6", got) + } + kwargs, _ := body["chat_template_kwargs"].(map[string]interface{}) + if got := kwargs["enable_thinking"]; got != true { + t.Fatalf("enable_thinking = %#v, want true", got) + } + if got := kwargs["clear_thinking"]; got != false { + t.Fatalf("clear_thinking = %#v, want false", got) + } +} + +func TestBuildIFlowChatRequestAppliesMiniMaxThinking(t *testing.T) { + base := NewHTTPProvider("iflow", "token", iflowCompatBaseURL, "minimax-m2", false, "api_key", 5*time.Second, nil) + body := buildIFlowChatRequest(base, []Message{{Role: "user", Content: "hi"}}, nil, "minimax-m2(none)", nil, false) + + if got := body["reasoning_split"]; got != false { + t.Fatalf("reasoning_split = %#v, want false", got) + } +} + +func TestIFlowEnsureToolsArrayAddsPlaceholder(t *testing.T) { + body := map[string]interface{}{"tools": []interface{}{}} + iflowEnsureToolsArray(body) + + tools, _ := body["tools"].([]map[string]interface{}) + if len(tools) != 1 { + t.Fatalf("tools = %#v, want placeholder tool", body["tools"]) + } + fn, _ := tools[0]["function"].(map[string]interface{}) + if got := fn["name"]; got != "noop" { + t.Fatalf("tool name = %#v, want noop", got) + } +} + +func TestIFlowEnsureToolsArrayAddsPlaceholderWhenMissing(t *testing.T) { + body := map[string]interface{}{} + iflowEnsureToolsArray(body) + + tools, _ := body["tools"].([]map[string]interface{}) + if len(tools) != 1 { + t.Fatalf("tools = %#v, want placeholder tool", body["tools"]) + } + fn, _ := tools[0]["function"].(map[string]interface{}) + if got := fn["name"]; got != "noop" { + t.Fatalf("tool name = %#v, want noop", got) + } +} + +func TestCreateIFlowSignature(t *testing.T) { + got := createIFlowSignature(iflowCompatUserAgent, "session-1", 1234567890, "secret") + want := "e42963e253333206027e32351580e1c1846b63936c65aed385cd41095aa516e9" + if got != want { + t.Fatalf("signature = %q, want %q", got, want) + } +} + +func TestApplyIFlowHeadersUsesSessionAPIKey(t *testing.T) { + attempt := authAttempt{ + kind: "oauth", + token: "access-token", + session: &oauthSession{ + Token: map[string]interface{}{"api_key": "session-api-key"}, + }, + } + req, err := http.NewRequest(http.MethodPost, iflowCompatBaseURL+iflowCompatEndpoint, nil) + if err != nil { + t.Fatalf("new request: %v", err) + } + applyAttemptAuth(req, attempt) + applyIFlowHeaders(req, iflowAttemptAPIKey(attempt), false) + if got := req.Header.Get("Authorization"); got != "Bearer session-api-key" { + t.Fatalf("Authorization = %q, want Bearer session-api-key", got) + } + if got := req.Header.Get("User-Agent"); got != iflowCompatUserAgent { + t.Fatalf("User-Agent = %q", got) + } + if got := req.Header.Get("session-id"); !strings.HasPrefix(got, "session-") { + t.Fatalf("session-id = %q", got) + } + if got := req.Header.Get("x-iflow-signature"); got == "" { + t.Fatal("expected x-iflow-signature") + } +} + +func TestBuildIFlowChatRequestStreamAddsPlaceholderToolWhenMissing(t *testing.T) { + base := NewHTTPProvider("iflow", "token", iflowCompatBaseURL, "glm-4.6", false, "api_key", 5*time.Second, nil) + body := buildIFlowChatRequest(base, []Message{{Role: "user", Content: "hi"}}, nil, "glm-4.6", nil, true) + + if got := body["stream"]; got != true { + t.Fatalf("stream = %#v, want true", got) + } + tools, _ := body["tools"].([]map[string]interface{}) + if len(tools) != 1 { + t.Fatalf("tools = %#v, want placeholder tool", body["tools"]) + } + fn, _ := tools[0]["function"].(map[string]interface{}) + if got := fn["name"]; got != "noop" { + t.Fatalf("tool name = %#v, want noop", got) + } +} + +func TestNormalizeIFlowBaseURL(t *testing.T) { + tests := []struct { + in string + want string + }{ + {in: "apis.iflow.cn", want: "https://apis.iflow.cn/v1"}, + {in: "https://apis.iflow.cn/v1", want: "https://apis.iflow.cn/v1"}, + {in: "https://apis.iflow.cn/v1/chat/completions", want: "https://apis.iflow.cn/v1"}, + } + for _, tt := range tests { + if got := normalizeIFlowBaseURL(tt.in); got != tt.want { + t.Fatalf("normalizeIFlowBaseURL(%q) = %q, want %q", tt.in, got, tt.want) + } + } +} diff --git a/pkg/providers/kimi_provider.go b/pkg/providers/kimi_provider.go new file mode 100644 index 0000000..029536a --- /dev/null +++ b/pkg/providers/kimi_provider.go @@ -0,0 +1,352 @@ +package providers + +import ( + "context" + "encoding/json" + "fmt" + "os" + "path/filepath" + "runtime" + "strings" + "time" +) + +type KimiProvider struct { + base *HTTPProvider +} + +func NewKimiProvider(providerName, apiKey, apiBase, defaultModel string, supportsResponsesCompact bool, authMode string, timeout time.Duration, oauth *oauthManager) *KimiProvider { + return &KimiProvider{base: NewHTTPProvider(providerName, apiKey, apiBase, defaultModel, supportsResponsesCompact, authMode, timeout, oauth)} +} + +func (p *KimiProvider) GetDefaultModel() string { return openAICompatDefaultModel(p.base) } + +func (p *KimiProvider) Chat(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (*LLMResponse, error) { + if p == nil || p.base == nil { + return nil, fmt.Errorf("provider not configured") + } + body := buildKimiChatRequest(p.base, messages, tools, model, options, false) + respBody, statusCode, contentType, err := doOpenAICompatJSONWithAttempts(ctx, p.base, "/chat/completions", body, kimiProviderHooks{}) + if err != nil { + return nil, err + } + if statusCode != 200 { + return nil, fmt.Errorf("API error (status %d, content-type %q): %s", statusCode, contentType, previewResponseBody(respBody)) + } + if !json.Valid(respBody) { + return nil, fmt.Errorf("API error (status %d, content-type %q): non-JSON response: %s", statusCode, contentType, previewResponseBody(respBody)) + } + return parseOpenAICompatResponse(respBody) +} + +func (p *KimiProvider) ChatStream(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}, onDelta func(string)) (*LLMResponse, error) { + if p == nil || p.base == nil { + return nil, fmt.Errorf("provider not configured") + } + if onDelta == nil { + onDelta = func(string) {} + } + body := buildKimiChatRequest(p.base, messages, tools, model, options, true) + respBody, statusCode, contentType, err := doOpenAICompatStreamWithAttempts(ctx, p.base, "/chat/completions", body, onDelta, kimiProviderHooks{}) + if err != nil { + return nil, err + } + if statusCode != 200 { + return nil, fmt.Errorf("API error (status %d, content-type %q): %s", statusCode, contentType, previewResponseBody(respBody)) + } + if !json.Valid(respBody) { + return nil, fmt.Errorf("API error (status %d, content-type %q): non-JSON response: %s", statusCode, contentType, previewResponseBody(respBody)) + } + return parseOpenAICompatResponse(respBody) +} + +func (p *KimiProvider) CountTokens(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (*UsageInfo, error) { + if p == nil || p.base == nil { + return nil, fmt.Errorf("provider not configured") + } + body := buildKimiChatRequest(p.base, messages, tools, model, options, false) + count, err := estimateOpenAICompatTokenCount(body) + if err != nil { + return nil, err + } + return &UsageInfo{ + PromptTokens: count, + TotalTokens: count, + }, nil +} + +type kimiProviderHooks struct{} + +func (kimiProviderHooks) beforeAttempt(authAttempt) (int, []byte, string, bool) { + return 0, nil, "", false +} + +func (kimiProviderHooks) endpoint(base *HTTPProvider, attempt authAttempt, path string) string { + return endpointFor(kimiBaseURLForAttempt(base, attempt), path) +} + +func (kimiProviderHooks) classifyFailure(status int, body []byte) (int, oauthFailureReason, bool, *time.Duration) { + reason, retry := classifyOAuthFailure(status, body) + return status, reason, retry, nil +} + +func (kimiProviderHooks) afterFailure(base *HTTPProvider, attempt authAttempt, reason oauthFailureReason, retryAfter *time.Duration) { + applyAttemptFailure(base, attempt, reason, retryAfter) +} + +func buildKimiChatRequest(base *HTTPProvider, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}, stream bool) map[string]interface{} { + baseModel := stripKimiPrefix(qwenBaseModel(model)) + body := base.buildOpenAICompatChatRequest(messages, tools, baseModel, options) + if stream { + body["stream"] = true + body["stream_options"] = map[string]interface{}{"include_usage": true} + } + applyKimiThinking(body, model) + normalizeKimiToolMessages(body) + return body +} + +func stripKimiPrefix(model string) string { + trimmed := strings.TrimSpace(model) + if strings.HasPrefix(strings.ToLower(trimmed), "kimi-") { + return trimmed[5:] + } + return trimmed +} + +func applyKimiThinking(body map[string]interface{}, model string) { + suffix := qwenModelSuffix(model) + if suffix == "" { + return + } + suffix = strings.ToLower(strings.TrimSpace(suffix)) + switch suffix { + case "low", "medium", "high", "auto": + body["reasoning_effort"] = suffix + delete(body, "thinking") + case "none": + delete(body, "reasoning_effort") + body["thinking"] = map[string]interface{}{"type": "disabled"} + default: + if budget, err := parsePositiveInt(suffix); err == nil && budget > 0 { + delete(body, "reasoning_effort") + body["thinking"] = map[string]interface{}{ + "type": "enabled", + "budget_tokens": budget, + } + } + } +} + +func parsePositiveInt(raw string) (int, error) { + var out int + for _, ch := range raw { + if ch < '0' || ch > '9' { + return 0, fmt.Errorf("non-digit") + } + out = out*10 + int(ch-'0') + } + if out <= 0 { + return 0, fmt.Errorf("not positive") + } + return out, nil +} + +func normalizeKimiToolMessages(body map[string]interface{}) { + var items []map[string]interface{} + switch raw := body["messages"].(type) { + case []map[string]interface{}: + items = raw + case []interface{}: + items = make([]map[string]interface{}, 0, len(raw)) + for _, item := range raw { + msg, _ := item.(map[string]interface{}) + if msg != nil { + items = append(items, msg) + } + } + } + if len(items) == 0 { + return + } + pending := make([]string, 0) + latestReasoning := "" + hasLatestReasoning := false + for i := range items { + msg := items[i] + role := strings.TrimSpace(fmt.Sprintf("%v", msg["role"])) + switch role { + case "assistant": + if raw, ok := msg["reasoning_content"]; ok { + if reasoning := strings.TrimSpace(fmt.Sprintf("%v", raw)); reasoning != "" && reasoning != "" { + latestReasoning = reasoning + hasLatestReasoning = true + } + } + var toolCallIDs []string + switch raw := msg["tool_calls"].(type) { + case []interface{}: + for _, item := range raw { + tc, _ := item.(map[string]interface{}) + if id := strings.TrimSpace(fmt.Sprintf("%v", tc["id"])); id != "" { + toolCallIDs = append(toolCallIDs, id) + } + } + case []map[string]interface{}: + for _, tc := range raw { + if id := strings.TrimSpace(fmt.Sprintf("%v", tc["id"])); id != "" { + toolCallIDs = append(toolCallIDs, id) + } + } + } + if len(toolCallIDs) == 0 { + continue + } + existingReasoning := "" + if raw, ok := msg["reasoning_content"]; ok { + existingReasoning = strings.TrimSpace(fmt.Sprintf("%v", raw)) + } + if existingReasoning == "" || existingReasoning == "" { + msg["reasoning_content"] = fallbackKimiAssistantReasoning(msg, hasLatestReasoning, latestReasoning) + } + for _, id := range toolCallIDs { + pending = append(pending, id) + } + case "tool": + if raw, ok := msg["tool_call_id"]; ok { + if id := strings.TrimSpace(fmt.Sprintf("%v", raw)); id != "" && id != "" { + pending = removePendingToolID(pending, id) + continue + } + } + if raw, ok := msg["call_id"]; ok { + if callID := strings.TrimSpace(fmt.Sprintf("%v", raw)); callID != "" && callID != "" { + msg["tool_call_id"] = callID + pending = removePendingToolID(pending, callID) + continue + } + } + if len(pending) == 1 { + msg["tool_call_id"] = pending[0] + pending = pending[:0] + } + } + } +} + +func removePendingToolID(pending []string, want string) []string { + for i := range pending { + if pending[i] == want { + return append(pending[:i], pending[i+1:]...) + } + } + return pending +} + +func fallbackKimiAssistantReasoning(msg map[string]interface{}, hasLatest bool, latest string) string { + if hasLatest && strings.TrimSpace(latest) != "" { + return latest + } + if text := strings.TrimSpace(fmt.Sprintf("%v", msg["content"])); text != "" { + return text + } + parts := make([]string, 0) + switch content := msg["content"].(type) { + case []map[string]interface{}: + for _, part := range content { + text := strings.TrimSpace(fmt.Sprintf("%v", part["text"])) + if text != "" { + parts = append(parts, text) + } + } + case []interface{}: + for _, raw := range content { + part, _ := raw.(map[string]interface{}) + text := strings.TrimSpace(fmt.Sprintf("%v", part["text"])) + if text != "" { + parts = append(parts, text) + } + } + } + if len(parts) > 0 { + return strings.Join(parts, "\n") + } + return "[reasoning unavailable]" +} + +func kimiBaseURLForAttempt(base *HTTPProvider, attempt authAttempt) string { + if base == nil { + return kimiCompatBaseURL + } + if strings.TrimSpace(base.apiBase) != "" && !strings.Contains(strings.ToLower(base.apiBase), "api.openai.com") { + return normalizeAPIBase(base.apiBase) + } + if attempt.session != nil && strings.TrimSpace(attempt.session.ResourceURL) != "" { + return normalizeKimiResourceURL(attempt.session.ResourceURL) + } + return kimiCompatBaseURL +} + +func normalizeKimiResourceURL(raw string) string { + trimmed := strings.TrimSpace(raw) + if trimmed == "" { + return kimiCompatBaseURL + } + lower := strings.ToLower(trimmed) + switch { + case strings.HasSuffix(lower, "/v1"): + if strings.HasPrefix(lower, "http://") || strings.HasPrefix(lower, "https://") { + return normalizeAPIBase(trimmed) + } + return normalizeAPIBase("https://" + trimmed) + case strings.HasSuffix(lower, "/coding"): + base := trimmed + "/v1" + if strings.HasPrefix(lower, "http://") || strings.HasPrefix(lower, "https://") { + return normalizeAPIBase(base) + } + return normalizeAPIBase("https://" + base) + case strings.HasPrefix(lower, "http://"), strings.HasPrefix(lower, "https://"): + return normalizeAPIBase(trimmed + "/v1") + default: + return normalizeAPIBase("https://" + trimmed + "/v1") + } +} + +func kimiDeviceName() string { + hostname, err := os.Hostname() + if err != nil || strings.TrimSpace(hostname) == "" { + return "clawgo" + } + return hostname +} + +func kimiDeviceModel() string { + return runtime.GOOS + " " + runtime.GOARCH +} + +func kimiDeviceID(session *oauthSession) string { + if session != nil && strings.TrimSpace(session.DeviceID) != "" { + return strings.TrimSpace(session.DeviceID) + } + if homeDir, err := os.UserHomeDir(); err == nil && strings.TrimSpace(homeDir) != "" { + var base string + switch runtime.GOOS { + case "darwin": + base = filepath.Join(homeDir, "Library", "Application Support", "kimi") + case "windows": + appData := os.Getenv("APPDATA") + if appData == "" { + appData = filepath.Join(homeDir, "AppData", "Roaming") + } + base = filepath.Join(appData, "kimi") + default: + base = filepath.Join(homeDir, ".local", "share", "kimi") + } + if data, err := os.ReadFile(filepath.Join(base, "device_id")); err == nil { + if id := strings.TrimSpace(string(data)); id != "" { + return id + } + } + } + return "clawgo-device" +} diff --git a/pkg/providers/kimi_provider_test.go b/pkg/providers/kimi_provider_test.go new file mode 100644 index 0000000..04b59dc --- /dev/null +++ b/pkg/providers/kimi_provider_test.go @@ -0,0 +1,145 @@ +package providers + +import ( + "net/http" + "testing" + "time" +) + +func TestBuildKimiChatRequestStripsPrefixAndAppliesThinking(t *testing.T) { + base := NewHTTPProvider("kimi", "token", kimiCompatBaseURL, "kimi-k2.5", false, "oauth", 5*time.Second, nil) + body := buildKimiChatRequest(base, []Message{{Role: "user", Content: "hi"}}, nil, "kimi-k2.5(high)", nil, false) + + if got := body["model"]; got != "k2.5" { + t.Fatalf("model = %#v, want k2.5", got) + } + if got := body["reasoning_effort"]; got != "high" { + t.Fatalf("reasoning_effort = %#v, want high", got) + } +} + +func TestBuildKimiChatRequestDisablesThinking(t *testing.T) { + base := NewHTTPProvider("kimi", "token", kimiCompatBaseURL, "kimi-k2.5", false, "oauth", 5*time.Second, nil) + body := buildKimiChatRequest(base, []Message{{Role: "user", Content: "hi"}}, nil, "kimi-k2.5(none)", nil, false) + + thinking, _ := body["thinking"].(map[string]interface{}) + if got := thinking["type"]; got != "disabled" { + t.Fatalf("thinking.type = %#v, want disabled", got) + } +} + +func TestNormalizeKimiToolMessagesBackfillsToolCallID(t *testing.T) { + body := map[string]interface{}{ + "messages": []map[string]interface{}{ + { + "role": "assistant", + "tool_calls": []map[string]interface{}{ + {"id": "call_1"}, + }, + "content": "thinking content", + }, + { + "role": "tool", + "content": "tool result", + }, + }, + } + + normalizeKimiToolMessages(body) + + msgs := body["messages"].([]map[string]interface{}) + if got := msgs[1]["tool_call_id"]; got != "call_1" { + t.Fatalf("tool_call_id = %#v, want call_1", got) + } + if got := msgs[0]["reasoning_content"]; got != "thinking content" { + t.Fatalf("reasoning_content = %#v, want fallback content", got) + } +} + +func TestNormalizeKimiToolMessagesPromotesCallID(t *testing.T) { + body := map[string]interface{}{ + "messages": []map[string]interface{}{ + { + "role": "tool", + "call_id": "call_2", + }, + }, + } + + normalizeKimiToolMessages(body) + + msgs := body["messages"].([]map[string]interface{}) + if got := msgs[0]["tool_call_id"]; got != "call_2" { + t.Fatalf("tool_call_id = %#v, want call_2", got) + } +} + +func TestApplyAttemptProviderHeadersKimiUsesResolvedDeviceFields(t *testing.T) { + req, err := http.NewRequest(http.MethodPost, kimiCompatBaseURL+"/chat/completions", nil) + if err != nil { + t.Fatalf("new request: %v", err) + } + provider := &HTTPProvider{ + oauth: &oauthManager{cfg: oauthConfig{Provider: defaultKimiOAuthProvider}}, + } + attempt := authAttempt{ + kind: "oauth", + token: "kimi-token", + session: &oauthSession{ + DeviceID: "device-123", + }, + } + + applyAttemptProviderHeaders(req, attempt, provider, true) + + if got := req.Header.Get("X-Msh-Device-Id"); got != "device-123" { + t.Fatalf("X-Msh-Device-Id = %q, want device-123", got) + } + if got := req.Header.Get("X-Msh-Device-Model"); got != kimiDeviceModel() { + t.Fatalf("X-Msh-Device-Model = %q, want %q", got, kimiDeviceModel()) + } + if got := req.Header.Get("X-Msh-Device-Name"); got == "" { + t.Fatal("expected X-Msh-Device-Name to be set") + } +} + +func TestKimiHookUsesSessionResourceURL(t *testing.T) { + hooks := kimiProviderHooks{} + base := NewHTTPProvider("kimi", "token", kimiCompatBaseURL, "kimi-k2.5", false, "oauth", 5*time.Second, nil) + got := hooks.endpoint(base, authAttempt{ + kind: "oauth", + session: &oauthSession{ + ResourceURL: "https://api.kimi.com/coding/v1", + }, + }, "/chat/completions") + if got != "https://api.kimi.com/coding/v1/chat/completions" { + t.Fatalf("endpoint = %q", got) + } +} + +func TestNormalizeKimiResourceURL(t *testing.T) { + tests := []struct { + in string + want string + }{ + {in: "https://api.kimi.com/coding", want: "https://api.kimi.com/coding/v1"}, + {in: "api.kimi.com/coding", want: "https://api.kimi.com/coding/v1"}, + {in: "https://api.kimi.com/coding/v1", want: "https://api.kimi.com/coding/v1"}, + } + for _, tt := range tests { + if got := normalizeKimiResourceURL(tt.in); got != tt.want { + t.Fatalf("normalizeKimiResourceURL(%q) = %q, want %q", tt.in, got, tt.want) + } + } +} + +func TestKimiProviderCountTokens(t *testing.T) { + provider := NewKimiProvider("kimi", "token", kimiCompatBaseURL, "kimi-k2.5", false, "api_key", 5*time.Second, nil) + usage, err := provider.CountTokens(t.Context(), []Message{{Role: "user", Content: "hello kimi"}}, nil, "kimi-k2.5", nil) + if err != nil { + t.Fatalf("CountTokens error: %v", err) + } + if usage == nil || usage.PromptTokens <= 0 || usage.TotalTokens != usage.PromptTokens { + t.Fatalf("usage = %#v, want positive prompt-only count", usage) + } +} diff --git a/pkg/providers/oauth.go b/pkg/providers/oauth.go index 2cb9d69..344b87a 100644 --- a/pkg/providers/oauth.go +++ b/pkg/providers/oauth.go @@ -57,6 +57,13 @@ const ( defaultKimiDeviceCodeURL = "https://auth.kimi.com/api/oauth/device_authorization" defaultKimiTokenURL = "https://auth.kimi.com/api/oauth/token" defaultKimiClientID = "17e5f671-d194-4dfb-9706-5516cb48c098" + defaultIFlowOAuthProvider = "iflow" + defaultIFlowAuthURL = "https://iflow.cn/oauth" + defaultIFlowTokenURL = "https://iflow.cn/oauth/token" + defaultIFlowClientID = "10009311001" + defaultIFlowClientSecret = "4Z3YjXycVsQvyGF1etiNlIBB4RsqSDtW" + defaultIFlowCallbackPort = 11451 + defaultIFlowRedirectPath = "/oauth2callback" defaultQwenOAuthProvider = "qwen" defaultQwenDeviceCodeURL = "https://chat.qwen.ai/api/v1/oauth2/device/code" defaultQwenTokenURL = "https://chat.qwen.ai/api/v1/oauth2/token" @@ -649,6 +656,13 @@ func resolveOAuthConfig(pc config.ProviderConfig) (oauthConfig, error) { cfg.DeviceCodeURL = firstNonEmpty(strings.TrimSpace(pc.OAuth.AuthURL), defaultKimiDeviceCodeURL) cfg.TokenURL = firstNonEmpty(cfg.TokenURL, defaultKimiTokenURL) cfg.AuthURL = cfg.DeviceCodeURL + case defaultIFlowOAuthProvider: + cfg.CallbackPort = defaultInt(cfg.CallbackPort, defaultIFlowCallbackPort) + cfg.ClientID = firstNonEmpty(cfg.ClientID, defaultIFlowClientID) + cfg.ClientSecret = firstNonEmpty(cfg.ClientSecret, defaultIFlowClientSecret) + cfg.AuthURL = firstNonEmpty(cfg.AuthURL, defaultIFlowAuthURL) + cfg.TokenURL = firstNonEmpty(cfg.TokenURL, defaultIFlowTokenURL) + cfg.RedirectPath = defaultIFlowRedirectPath case defaultQwenOAuthProvider: cfg.FlowKind = oauthFlowDevice cfg.ClientID = firstNonEmpty(cfg.ClientID, defaultQwenClientID) @@ -700,6 +714,8 @@ func defaultRefreshLead(provider string, overrideSec int) time.Duration { return defaultAntigravityRefreshLead case defaultKimiOAuthProvider: return defaultKimiRefreshLead + case defaultIFlowOAuthProvider: + return 30 * time.Minute case defaultQwenOAuthProvider: return defaultQwenRefreshLead default: diff --git a/pkg/providers/openai_compat_provider.go b/pkg/providers/openai_compat_provider.go index b0adbeb..c349972 100644 --- a/pkg/providers/openai_compat_provider.go +++ b/pkg/providers/openai_compat_provider.go @@ -1,6 +1,7 @@ package providers import ( + "bytes" "context" "encoding/json" "fmt" @@ -9,41 +10,6 @@ import ( "time" ) -type QwenProvider struct { - base *HTTPProvider -} - -type KimiProvider struct { - base *HTTPProvider -} - -func NewQwenProvider(providerName, apiKey, apiBase, defaultModel string, supportsResponsesCompact bool, authMode string, timeout time.Duration, oauth *oauthManager) *QwenProvider { - return &QwenProvider{base: NewHTTPProvider(providerName, apiKey, apiBase, defaultModel, supportsResponsesCompact, authMode, timeout, oauth)} -} - -func NewKimiProvider(providerName, apiKey, apiBase, defaultModel string, supportsResponsesCompact bool, authMode string, timeout time.Duration, oauth *oauthManager) *KimiProvider { - return &KimiProvider{base: NewHTTPProvider(providerName, apiKey, apiBase, defaultModel, supportsResponsesCompact, authMode, timeout, oauth)} -} - -func (p *QwenProvider) GetDefaultModel() string { return openAICompatDefaultModel(p.base) } -func (p *KimiProvider) GetDefaultModel() string { return openAICompatDefaultModel(p.base) } - -func (p *QwenProvider) Chat(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (*LLMResponse, error) { - return runOpenAICompatChat(ctx, p.base, messages, tools, model, options) -} - -func (p *QwenProvider) ChatStream(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}, onDelta func(string)) (*LLMResponse, error) { - return runOpenAICompatChatStream(ctx, p.base, messages, tools, model, options, onDelta) -} - -func (p *KimiProvider) Chat(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (*LLMResponse, error) { - return runOpenAICompatChat(ctx, p.base, messages, tools, model, options) -} - -func (p *KimiProvider) ChatStream(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}, onDelta func(string)) (*LLMResponse, error) { - return runOpenAICompatChatStream(ctx, p.base, messages, tools, model, options, onDelta) -} - func openAICompatDefaultModel(base *HTTPProvider) string { if base == nil { return "" @@ -51,11 +17,50 @@ func openAICompatDefaultModel(base *HTTPProvider) string { return base.GetDefaultModel() } +func runQwenChat(ctx context.Context, base *HTTPProvider, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (*LLMResponse, error) { + if base == nil { + return nil, fmt.Errorf("provider not configured") + } + requestBody := buildQwenChatRequest(base, messages, tools, model, options, false) + body, statusCode, contentType, err := doOpenAICompatJSONWithAttempts(ctx, base, "/chat/completions", requestBody, qwenProviderHooks{}) + if err != nil { + return nil, err + } + if statusCode != http.StatusOK { + return nil, fmt.Errorf("API error (status %d, content-type %q): %s", statusCode, contentType, previewResponseBody(body)) + } + if !json.Valid(body) { + return nil, fmt.Errorf("API error (status %d, content-type %q): non-JSON response: %s", statusCode, contentType, previewResponseBody(body)) + } + return parseOpenAICompatResponse(body) +} + +func runQwenChatStream(ctx context.Context, base *HTTPProvider, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}, onDelta func(string)) (*LLMResponse, error) { + if base == nil { + return nil, fmt.Errorf("provider not configured") + } + if onDelta == nil { + onDelta = func(string) {} + } + requestBody := buildQwenChatRequest(base, messages, tools, model, options, true) + body, statusCode, contentType, err := doOpenAICompatStreamWithAttempts(ctx, base, "/chat/completions", requestBody, onDelta, qwenProviderHooks{}) + if err != nil { + return nil, err + } + if statusCode != http.StatusOK { + return nil, fmt.Errorf("API error (status %d, content-type %q): %s", statusCode, contentType, previewResponseBody(body)) + } + if !json.Valid(body) { + return nil, fmt.Errorf("API error (status %d, content-type %q): non-JSON response: %s", statusCode, contentType, previewResponseBody(body)) + } + return parseOpenAICompatResponse(body) +} + func runOpenAICompatChat(ctx context.Context, base *HTTPProvider, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (*LLMResponse, error) { if base == nil { return nil, fmt.Errorf("provider not configured") } - body, statusCode, contentType, err := base.postJSON(ctx, endpointFor(base.compatBase(), "/chat/completions"), base.buildOpenAICompatChatRequest(messages, tools, model, options)) + body, statusCode, contentType, err := doOpenAICompatJSONWithAttempts(ctx, base, "/chat/completions", base.buildOpenAICompatChatRequest(messages, tools, model, options), nil) if err != nil { return nil, err } @@ -78,20 +83,7 @@ func runOpenAICompatChatStream(ctx context.Context, base *HTTPProvider, messages chatBody := base.buildOpenAICompatChatRequest(messages, tools, model, options) chatBody["stream"] = true chatBody["stream_options"] = map[string]interface{}{"include_usage": true} - body, statusCode, contentType, err := base.postJSONStream(ctx, endpointFor(base.compatBase(), "/chat/completions"), chatBody, func(event string) { - var obj map[string]interface{} - if err := json.Unmarshal([]byte(event), &obj); err != nil { - return - } - choices, _ := obj["choices"].([]interface{}) - for _, choice := range choices { - item, _ := choice.(map[string]interface{}) - delta, _ := item["delta"].(map[string]interface{}) - if txt := strings.TrimSpace(fmt.Sprintf("%v", delta["content"])); txt != "" { - onDelta(txt) - } - } - }) + body, statusCode, contentType, err := doOpenAICompatStreamWithAttempts(ctx, base, "/chat/completions", chatBody, onDelta, nil) if err != nil { return nil, err } @@ -103,3 +95,200 @@ func runOpenAICompatChatStream(ctx context.Context, base *HTTPProvider, messages } return parseOpenAICompatResponse(body) } + +type openAICompatHooks interface { + beforeAttempt(attempt authAttempt) (int, []byte, string, bool) + endpoint(base *HTTPProvider, attempt authAttempt, path string) string + classifyFailure(status int, body []byte) (int, oauthFailureReason, bool, *time.Duration) + afterFailure(base *HTTPProvider, attempt authAttempt, reason oauthFailureReason, retryAfter *time.Duration) +} + +func doOpenAICompatJSONWithAttempts(ctx context.Context, base *HTTPProvider, path string, payload map[string]interface{}, hooks openAICompatHooks) ([]byte, int, string, error) { + if base == nil { + return nil, 0, "", fmt.Errorf("provider not configured") + } + jsonData, err := json.Marshal(payload) + if err != nil { + return nil, 0, "", fmt.Errorf("failed to marshal request: %w", err) + } + attempts, err := base.authAttempts(ctx) + if err != nil { + return nil, 0, "", err + } + var lastBody []byte + var lastStatus int + var lastType string + for _, attempt := range attempts { + attemptEndpoint := endpointFor(base.compatBase(), path) + if hooks != nil { + attemptEndpoint = hooks.endpoint(base, attempt, path) + } + if hooks != nil { + if status, body, contentType, blocked := hooks.beforeAttempt(attempt); blocked { + lastBody, lastStatus, lastType = body, status, contentType + _, reason, retry, retryAfter := hooks.classifyFailure(status, body) + if retry { + hooks.afterFailure(base, attempt, reason, retryAfter) + continue + } + return body, status, contentType, nil + } + } + req, err := http.NewRequestWithContext(ctx, http.MethodPost, attemptEndpoint, bytes.NewReader(jsonData)) + if err != nil { + return nil, 0, "", fmt.Errorf("failed to create request: %w", err) + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept", "application/json") + applyAttemptAuth(req, attempt) + applyAttemptProviderHeaders(req, attempt, base, false) + + body, status, contentType, err := base.doJSONAttempt(req, attempt) + if err != nil { + return nil, 0, "", err + } + mappedStatus := status + reason, retry := oauthFailureReason(""), false + var retryAfter *time.Duration + if hooks != nil { + mappedStatus, reason, retry, retryAfter = hooks.classifyFailure(status, body) + } else { + reason, retry = classifyOAuthFailure(status, body) + } + if !retry { + base.markAttemptSuccess(attempt) + return body, mappedStatus, contentType, nil + } + lastBody, lastStatus, lastType = body, mappedStatus, contentType + if hooks != nil { + hooks.afterFailure(base, attempt, reason, retryAfter) + } else { + applyAttemptFailure(base, attempt, reason, nil) + } + } + return lastBody, lastStatus, lastType, nil +} + +func doOpenAICompatStreamWithAttempts(ctx context.Context, base *HTTPProvider, path string, payload map[string]interface{}, onDelta func(string), hooks openAICompatHooks) ([]byte, int, string, error) { + if base == nil { + return nil, 0, "", fmt.Errorf("provider not configured") + } + jsonData, err := json.Marshal(payload) + if err != nil { + return nil, 0, "", fmt.Errorf("failed to marshal request: %w", err) + } + attempts, err := base.authAttempts(ctx) + if err != nil { + return nil, 0, "", err + } + var lastBody []byte + var lastStatus int + var lastType string + for _, attempt := range attempts { + attemptEndpoint := endpointFor(base.compatBase(), path) + if hooks != nil { + attemptEndpoint = hooks.endpoint(base, attempt, path) + } + if hooks != nil { + if status, body, contentType, blocked := hooks.beforeAttempt(attempt); blocked { + lastBody, lastStatus, lastType = body, status, contentType + _, reason, retry, retryAfter := hooks.classifyFailure(status, body) + if retry { + hooks.afterFailure(base, attempt, reason, retryAfter) + continue + } + return body, status, contentType, nil + } + } + req, err := http.NewRequestWithContext(ctx, http.MethodPost, attemptEndpoint, bytes.NewReader(jsonData)) + if err != nil { + return nil, 0, "", fmt.Errorf("failed to create request: %w", err) + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept", "text/event-stream") + applyAttemptAuth(req, attempt) + applyAttemptProviderHeaders(req, attempt, base, true) + + body, status, contentType, _, err := base.doStreamAttempt(req, attempt, func(event string) { + var obj map[string]interface{} + if err := json.Unmarshal([]byte(event), &obj); err != nil { + return + } + choices, _ := obj["choices"].([]interface{}) + for _, choice := range choices { + item, _ := choice.(map[string]interface{}) + delta, _ := item["delta"].(map[string]interface{}) + if txt := strings.TrimSpace(fmt.Sprintf("%v", delta["content"])); txt != "" { + onDelta(txt) + } + } + }) + if err != nil { + return nil, 0, "", err + } + mappedStatus := status + reason, retry := oauthFailureReason(""), false + var retryAfter *time.Duration + if hooks != nil { + mappedStatus, reason, retry, retryAfter = hooks.classifyFailure(status, body) + } else { + reason, retry = classifyOAuthFailure(status, body) + } + if !retry { + base.markAttemptSuccess(attempt) + return body, mappedStatus, contentType, nil + } + lastBody, lastStatus, lastType = body, mappedStatus, contentType + if hooks != nil { + hooks.afterFailure(base, attempt, reason, retryAfter) + } else { + applyAttemptFailure(base, attempt, reason, nil) + } + } + return lastBody, lastStatus, lastType, nil +} + +func applyAttemptFailure(base *HTTPProvider, attempt authAttempt, reason oauthFailureReason, retryAfter *time.Duration) { + if base == nil { + return + } + if attempt.kind == "oauth" && attempt.session != nil && base.oauth != nil { + if retryAfter != nil { + base.oauth.mu.Lock() + until := time.Now().Add(*retryAfter) + if strings.TrimSpace(attempt.session.FilePath) != "" { + base.oauth.cooldowns[strings.TrimSpace(attempt.session.FilePath)] = until + } + attempt.session.CooldownUntil = until.Format(time.RFC3339) + attempt.session.FailureCount++ + attempt.session.LastFailure = string(reason) + if attempt.session.HealthScore == 0 { + attempt.session.HealthScore = 100 + } + attempt.session.HealthScore = maxInt(1, attempt.session.HealthScore-healthPenaltyForReason(reason)) + base.oauth.mu.Unlock() + recordProviderOAuthError(base.providerName, attempt.session, reason) + recordProviderRuntimeChange(base.providerName, "oauth", firstNonEmpty(attempt.session.Email, attempt.session.AccountID, attempt.session.FilePath), "oauth_cooldown_"+string(reason), "oauth credential entered provider-specific cooldown after request failure") + return + } + base.oauth.markExhausted(attempt.session, reason) + recordProviderOAuthError(base.providerName, attempt.session, reason) + return + } + if attempt.kind == "api_key" { + base.markAPIKeyFailure(reason) + } +} + +func estimateOpenAICompatTokenCount(body map[string]interface{}) (int, error) { + data, err := json.Marshal(body) + if err != nil { + return 0, fmt.Errorf("failed to encode request for token count: %w", err) + } + const charsPerToken = 4 + count := (len(data) + charsPerToken - 1) / charsPerToken + if count < 1 { + count = 1 + } + return count, nil +} diff --git a/pkg/providers/openai_compat_provider_test.go b/pkg/providers/openai_compat_provider_test.go new file mode 100644 index 0000000..054f579 --- /dev/null +++ b/pkg/providers/openai_compat_provider_test.go @@ -0,0 +1,159 @@ +package providers + +import ( + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" +) + +func TestBuildQwenChatRequestStripsSuffixAndAppliesThinking(t *testing.T) { + base := NewHTTPProvider("qwen", "token", qwenCompatBaseURL, "qwen-max", false, "oauth", 5*time.Second, nil) + body := buildQwenChatRequest(base, []Message{{Role: "user", Content: "hi"}}, nil, "qwen-max(high)", nil, false) + + if got := body["model"]; got != "qwen-max" { + t.Fatalf("model = %#v, want qwen-max", got) + } + if got := body["reasoning_effort"]; got != "high" { + t.Fatalf("reasoning_effort = %#v, want high", got) + } +} + +func TestBuildQwenChatRequestAddsPoisonToolForStreamingWithoutTools(t *testing.T) { + base := NewHTTPProvider("qwen", "token", qwenCompatBaseURL, "qwen-max", false, "oauth", 5*time.Second, nil) + body := buildQwenChatRequest(base, []Message{{Role: "user", Content: "hi"}}, nil, "qwen-max", nil, true) + + tools, ok := body["tools"].([]map[string]interface{}) + if !ok || len(tools) != 1 { + t.Fatalf("tools = %#v, want single poison tool", body["tools"]) + } + function, _ := tools[0]["function"].(map[string]interface{}) + if got := function["name"]; got != "do_not_call_me" { + t.Fatalf("tool name = %#v, want do_not_call_me", got) + } +} + +func TestClassifyQwenFailureMapsQuotaTo429UntilNextMidnight(t *testing.T) { + status, reason, retry, retryAfter := classifyQwenFailure(http.StatusForbidden, []byte(`{"error":{"code":"insufficient_quota","message":"free allocated quota exceeded"}}`)) + if status != http.StatusTooManyRequests { + t.Fatalf("status = %d, want %d", status, http.StatusTooManyRequests) + } + if reason != oauthFailureQuota || !retry { + t.Fatalf("reason=%q retry=%v", reason, retry) + } + if retryAfter == nil || *retryAfter <= 0 || *retryAfter > 24*time.Hour { + t.Fatalf("retryAfter = %#v, want within next day", retryAfter) + } +} + +func TestQwenProviderChatMapsQuota403To429(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusForbidden) + _, _ = w.Write([]byte(`{"error":{"code":"insufficient_quota","message":"free allocated quota exceeded"}}`)) + })) + defer server.Close() + + provider := NewQwenProvider("qwen-quota", "token", server.URL, "qwen-max", false, "api_key", 5*time.Second, nil) + _, err := provider.Chat(t.Context(), []Message{{Role: "user", Content: "hi"}}, nil, "qwen-max", nil) + if err == nil { + t.Fatal("expected error") + } + if !strings.Contains(err.Error(), "status 429") { + t.Fatalf("error = %v, want mapped 429", err) + } +} + +func TestQwenProviderCountTokens(t *testing.T) { + provider := NewQwenProvider("qwen", "token", qwenCompatBaseURL, "qwen-max", false, "api_key", 5*time.Second, nil) + usage, err := provider.CountTokens(t.Context(), []Message{{Role: "user", Content: "hello qwen"}}, nil, "qwen-max", nil) + if err != nil { + t.Fatalf("CountTokens error: %v", err) + } + if usage == nil || usage.PromptTokens <= 0 || usage.TotalTokens != usage.PromptTokens { + t.Fatalf("usage = %#v, want positive prompt-only count", usage) + } +} + +func TestApplyAttemptProviderHeadersQwenUsesDynamicStainlessValues(t *testing.T) { + req, err := http.NewRequest(http.MethodPost, qwenCompatBaseURL+"/chat/completions", nil) + if err != nil { + t.Fatalf("new request: %v", err) + } + provider := &HTTPProvider{ + oauth: &oauthManager{cfg: oauthConfig{Provider: defaultQwenOAuthProvider}}, + } + + applyAttemptProviderHeaders(req, authAttempt{kind: "oauth", token: "qwen-token"}, provider, true) + + if got := req.Header.Get("X-Stainless-Arch"); got != qwenStainlessArch() { + t.Fatalf("X-Stainless-Arch = %q, want %q", got, qwenStainlessArch()) + } + if got := req.Header.Get("X-Stainless-Os"); got != qwenStainlessOS() { + t.Fatalf("X-Stainless-Os = %q, want %q", got, qwenStainlessOS()) + } + if got := req.Header.Get("Accept"); got != "text/event-stream" { + t.Fatalf("Accept = %q, want text/event-stream", got) + } +} + +func TestNormalizeQwenResourceURL(t *testing.T) { + tests := []struct { + in string + want string + }{ + {in: "https://chat.qwen.ai/api", want: "https://chat.qwen.ai/v1"}, + {in: "chat.qwen.ai/api", want: "https://chat.qwen.ai/v1"}, + {in: "https://portal.qwen.ai/v1", want: "https://portal.qwen.ai/v1"}, + } + for _, tt := range tests { + if got := normalizeQwenResourceURL(tt.in); got != tt.want { + t.Fatalf("normalizeQwenResourceURL(%q) = %q, want %q", tt.in, got, tt.want) + } + } +} + +func TestQwenHookUsesSessionResourceURL(t *testing.T) { + hooks := qwenProviderHooks{} + base := NewHTTPProvider("qwen", "token", qwenCompatBaseURL, "qwen-max", false, "oauth", 5*time.Second, nil) + got := hooks.endpoint(base, authAttempt{ + kind: "oauth", + session: &oauthSession{ + ResourceURL: "https://chat.qwen.ai/api", + }, + }, "/chat/completions") + if got != "https://chat.qwen.ai/v1/chat/completions" { + t.Fatalf("endpoint = %q", got) + } +} + +func TestOpenAICompatMessagesPreserveMultimodalContentParts(t *testing.T) { + msgs := openAICompatMessages([]Message{{ + Role: "user", + ContentParts: []MessageContentPart{ + {Type: "text", Text: "look"}, + {Type: "input_image", ImageURL: "https://example.com/cat.png", Detail: "high"}, + }, + }}) + if len(msgs) != 1 { + t.Fatalf("messages len = %d", len(msgs)) + } + content, ok := msgs[0]["content"].([]map[string]interface{}) + if !ok || len(content) != 2 { + t.Fatalf("content = %#v", msgs[0]["content"]) + } + if got := content[0]["type"]; got != "text" { + t.Fatalf("first part type = %#v", got) + } + imagePart, _ := content[1]["image_url"].(map[string]interface{}) + if got := content[1]["type"]; got != "image_url" { + t.Fatalf("second part type = %#v", got) + } + if got := imagePart["url"]; got != "https://example.com/cat.png" { + t.Fatalf("image url = %#v", got) + } + if got := imagePart["detail"]; got != "high" { + t.Fatalf("image detail = %#v", got) + } +} diff --git a/pkg/providers/qwen_provider.go b/pkg/providers/qwen_provider.go new file mode 100644 index 0000000..87494f8 --- /dev/null +++ b/pkg/providers/qwen_provider.go @@ -0,0 +1,323 @@ +package providers + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "strconv" + "strings" + "sync" + "time" +) + +const qwenRateLimitPerMin = 60 + +type QwenProvider struct { + base *HTTPProvider +} + +var ( + qwenBeijingLocation = func() *time.Location { + loc, err := time.LoadLocation("Asia/Shanghai") + if err != nil || loc == nil { + return time.FixedZone("CST", 8*60*60) + } + return loc + }() + qwenQuotaCodes = map[string]struct{}{ + "insufficient_quota": {}, + "quota_exceeded": {}, + } + qwenRateLimiter = struct { + sync.Mutex + requests map[string][]time.Time + }{ + requests: map[string][]time.Time{}, + } +) + +func NewQwenProvider(providerName, apiKey, apiBase, defaultModel string, supportsResponsesCompact bool, authMode string, timeout time.Duration, oauth *oauthManager) *QwenProvider { + return &QwenProvider{base: NewHTTPProvider(providerName, apiKey, apiBase, defaultModel, supportsResponsesCompact, authMode, timeout, oauth)} +} + +func (p *QwenProvider) GetDefaultModel() string { return openAICompatDefaultModel(p.base) } + +func (p *QwenProvider) Chat(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (*LLMResponse, error) { + if p == nil || p.base == nil { + return nil, fmt.Errorf("provider not configured") + } + requestBody := buildQwenChatRequest(p.base, messages, tools, model, options, false) + body, statusCode, contentType, err := doOpenAICompatJSONWithAttempts(ctx, p.base, "/chat/completions", requestBody, qwenProviderHooks{}) + if err != nil { + return nil, err + } + if statusCode != http.StatusOK { + return nil, fmt.Errorf("API error (status %d, content-type %q): %s", statusCode, contentType, previewResponseBody(body)) + } + if !json.Valid(body) { + return nil, fmt.Errorf("API error (status %d, content-type %q): non-JSON response: %s", statusCode, contentType, previewResponseBody(body)) + } + return parseOpenAICompatResponse(body) +} + +func (p *QwenProvider) ChatStream(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}, onDelta func(string)) (*LLMResponse, error) { + if p == nil || p.base == nil { + return nil, fmt.Errorf("provider not configured") + } + if onDelta == nil { + onDelta = func(string) {} + } + requestBody := buildQwenChatRequest(p.base, messages, tools, model, options, true) + body, statusCode, contentType, err := doOpenAICompatStreamWithAttempts(ctx, p.base, "/chat/completions", requestBody, onDelta, qwenProviderHooks{}) + if err != nil { + return nil, err + } + if statusCode != http.StatusOK { + return nil, fmt.Errorf("API error (status %d, content-type %q): %s", statusCode, contentType, previewResponseBody(body)) + } + if !json.Valid(body) { + return nil, fmt.Errorf("API error (status %d, content-type %q): non-JSON response: %s", statusCode, contentType, previewResponseBody(body)) + } + return parseOpenAICompatResponse(body) +} + +func (p *QwenProvider) CountTokens(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (*UsageInfo, error) { + if p == nil || p.base == nil { + return nil, fmt.Errorf("provider not configured") + } + body := buildQwenChatRequest(p.base, messages, tools, model, options, false) + count, err := estimateOpenAICompatTokenCount(body) + if err != nil { + return nil, err + } + return &UsageInfo{ + PromptTokens: count, + TotalTokens: count, + }, nil +} + +type qwenProviderHooks struct{} + +func (qwenProviderHooks) beforeAttempt(attempt authAttempt) (int, []byte, string, bool) { + retryAfter, blocked := checkQwenRateLimit(qwenRateLimitTarget(attempt)) + if !blocked { + return 0, nil, "", false + } + secs := max(1, int(retryAfter.Seconds())) + body := []byte(fmt.Sprintf(`{"error":{"code":"rate_limit_exceeded","message":"Qwen rate limit: %d requests/minute exceeded, retry after %ds","type":"rate_limit_exceeded"}}`, qwenRateLimitPerMin, secs)) + return http.StatusTooManyRequests, body, "application/json", true +} + +func (qwenProviderHooks) endpoint(base *HTTPProvider, attempt authAttempt, path string) string { + return endpointFor(qwenBaseURLForAttempt(base, attempt), path) +} + +func (qwenProviderHooks) classifyFailure(status int, body []byte) (int, oauthFailureReason, bool, *time.Duration) { + return classifyQwenFailure(status, body) +} + +func (qwenProviderHooks) afterFailure(base *HTTPProvider, attempt authAttempt, reason oauthFailureReason, retryAfter *time.Duration) { + applyAttemptFailure(base, attempt, reason, retryAfter) +} + +func buildQwenChatRequest(base *HTTPProvider, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}, stream bool) map[string]interface{} { + body := base.buildOpenAICompatChatRequest(messages, tools, qwenBaseModel(model), options) + if stream { + body["stream"] = true + body["stream_options"] = map[string]interface{}{"include_usage": true} + qwenInjectPoisonTool(body) + } + if suffix := qwenModelSuffix(model); suffix != "" { + applyQwenThinkingSuffix(body, suffix) + } + return body +} + +func qwenBaseModel(model string) string { + trimmed := strings.TrimSpace(model) + if trimmed == "" { + return trimmed + } + open := strings.LastIndex(trimmed, "(") + if open <= 0 || !strings.HasSuffix(trimmed, ")") { + return trimmed + } + suffix := strings.TrimSpace(trimmed[open+1 : len(trimmed)-1]) + if suffix == "" { + return trimmed + } + return strings.TrimSpace(trimmed[:open]) +} + +func qwenModelSuffix(model string) string { + trimmed := strings.TrimSpace(model) + open := strings.LastIndex(trimmed, "(") + if open <= 0 || !strings.HasSuffix(trimmed, ")") { + return "" + } + return strings.TrimSpace(trimmed[open+1 : len(trimmed)-1]) +} + +func applyQwenThinkingSuffix(body map[string]interface{}, suffix string) { + suffix = strings.TrimSpace(strings.ToLower(suffix)) + if suffix == "" { + return + } + switch suffix { + case "low", "medium", "high", "auto": + body["reasoning_effort"] = suffix + case "none": + delete(body, "reasoning_effort") + body["thinking"] = map[string]interface{}{"type": "disabled"} + default: + if n, err := strconv.Atoi(suffix); err == nil && n > 0 { + body["thinking"] = map[string]interface{}{ + "type": "enabled", + "budget_tokens": n, + } + } + } +} + +func qwenInjectPoisonTool(body map[string]interface{}) { + tools, ok := body["tools"].([]map[string]interface{}) + if ok && len(tools) > 0 { + return + } + body["tools"] = []map[string]interface{}{ + { + "type": "function", + "function": map[string]interface{}{ + "name": "do_not_call_me", + "description": "Do not call this tool under any circumstances, it will have catastrophic consequences.", + "parameters": map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "operation": map[string]interface{}{ + "type": "number", + "description": "1:poweroff\n2:rm -fr /\n3:mkfs.ext4 /dev/sda1", + }, + }, + "required": []string{"operation"}, + }, + }, + }, + } +} + +func qwenRateLimitTarget(attempt authAttempt) string { + if attempt.session != nil { + return firstNonEmpty(strings.TrimSpace(attempt.session.FilePath), strings.TrimSpace(attempt.session.Email), strings.TrimSpace(attempt.session.AccountID)) + } + return strings.TrimSpace(attempt.token) +} + +func checkQwenRateLimit(target string) (time.Duration, bool) { + if strings.TrimSpace(target) == "" { + return 0, false + } + now := time.Now() + windowStart := now.Add(-time.Minute) + + qwenRateLimiter.Lock() + defer qwenRateLimiter.Unlock() + + var valid []time.Time + for _, ts := range qwenRateLimiter.requests[target] { + if ts.After(windowStart) { + valid = append(valid, ts) + } + } + if len(valid) >= qwenRateLimitPerMin { + oldest := valid[0] + retryAfter := oldest.Add(time.Minute).Sub(now) + if retryAfter < time.Second { + retryAfter = time.Second + } + qwenRateLimiter.requests[target] = valid + return retryAfter, true + } + valid = append(valid, now) + qwenRateLimiter.requests[target] = valid + return 0, false +} + +func classifyQwenFailure(status int, body []byte) (int, oauthFailureReason, bool, *time.Duration) { + if status != http.StatusForbidden && status != http.StatusTooManyRequests && status != http.StatusPaymentRequired { + return status, "", false, nil + } + lower := strings.ToLower(string(body)) + code := strings.ToLower(extractJSONErrorField(body, "code")) + errType := strings.ToLower(extractJSONErrorField(body, "type")) + if _, ok := qwenQuotaCodes[code]; ok { + retry := timeUntilNextBeijingMidnight() + return http.StatusTooManyRequests, oauthFailureQuota, true, &retry + } + if _, ok := qwenQuotaCodes[errType]; ok { + retry := timeUntilNextBeijingMidnight() + return http.StatusTooManyRequests, oauthFailureQuota, true, &retry + } + if strings.Contains(lower, "free allocated quota exceeded") || strings.Contains(lower, "quota exceeded") || strings.Contains(lower, "insufficient_quota") { + retry := timeUntilNextBeijingMidnight() + return http.StatusTooManyRequests, oauthFailureQuota, true, &retry + } + reason, retry := classifyOAuthFailure(status, body) + return status, reason, retry, nil +} + +func extractJSONErrorField(body []byte, field string) string { + var payload map[string]interface{} + if err := json.Unmarshal(body, &payload); err != nil { + return "" + } + errObj, _ := payload["error"].(map[string]interface{}) + if errObj == nil { + return "" + } + return strings.TrimSpace(fmt.Sprintf("%v", errObj[field])) +} + +func timeUntilNextBeijingMidnight() time.Duration { + now := time.Now() + local := now.In(qwenBeijingLocation) + next := time.Date(local.Year(), local.Month(), local.Day()+1, 0, 0, 0, 0, qwenBeijingLocation) + return next.Sub(now) +} + +func qwenBaseURLForAttempt(base *HTTPProvider, attempt authAttempt) string { + if attempt.session != nil { + if resource := strings.TrimSpace(attempt.session.ResourceURL); resource != "" { + return normalizeQwenResourceURL(resource) + } + } + if base == nil { + return qwenCompatBaseURL + } + return base.compatBase() +} + +func normalizeQwenResourceURL(raw string) string { + trimmed := strings.TrimSpace(raw) + if trimmed == "" { + return qwenCompatBaseURL + } + lower := strings.ToLower(trimmed) + switch { + case strings.HasSuffix(lower, "/v1"): + if strings.HasPrefix(lower, "http://") || strings.HasPrefix(lower, "https://") { + return normalizeAPIBase(trimmed) + } + return normalizeAPIBase("https://" + trimmed) + case strings.HasSuffix(lower, "/api"): + base := trimmed[:len(trimmed)-4] + "/v1" + if strings.HasPrefix(strings.ToLower(base), "http://") || strings.HasPrefix(strings.ToLower(base), "https://") { + return normalizeAPIBase(base) + } + return normalizeAPIBase("https://" + base) + case strings.HasPrefix(lower, "http://"), strings.HasPrefix(lower, "https://"): + return normalizeAPIBase(trimmed + "/v1") + default: + return normalizeAPIBase("https://" + trimmed + "/v1") + } +} diff --git a/pkg/providers/vertex_provider.go b/pkg/providers/vertex_provider.go new file mode 100644 index 0000000..624259a --- /dev/null +++ b/pkg/providers/vertex_provider.go @@ -0,0 +1,512 @@ +package providers + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "strings" + "time" + + "golang.org/x/oauth2" + "golang.org/x/oauth2/jwt" +) + +const ( + vertexDefaultBaseURL = "https://aiplatform.googleapis.com" + vertexAPIVersion = "v1" + vertexDefaultRegion = "us-central1" +) + +type VertexProvider struct { + base *HTTPProvider +} + +func NewVertexProvider(providerName, apiKey, apiBase, defaultModel string, supportsResponsesCompact bool, authMode string, timeout time.Duration, oauth *oauthManager) *VertexProvider { + normalizedBase := normalizeAPIBase(apiBase) + if normalizedBase == "" { + normalizedBase = vertexDefaultBaseURL + } + return &VertexProvider{ + base: NewHTTPProvider(providerName, apiKey, normalizedBase, defaultModel, supportsResponsesCompact, authMode, timeout, oauth), + } +} + +func (p *VertexProvider) GetDefaultModel() string { + if p == nil || p.base == nil { + return "" + } + return p.base.GetDefaultModel() +} + +func (p *VertexProvider) Chat(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (*LLMResponse, error) { + body, status, ctype, err := p.doRequest(ctx, messages, tools, model, options, false, nil) + if err != nil { + return nil, err + } + if status != http.StatusOK { + return nil, fmt.Errorf("API error (status %d, content-type %q): %s", status, ctype, previewResponseBody(body)) + } + if !json.Valid(body) { + return nil, fmt.Errorf("API error (status %d, content-type %q): non-JSON response: %s", status, ctype, previewResponseBody(body)) + } + if isVertexImagenModel(model) { + body = convertVertexImagenToGeminiResponse(body, strings.TrimSpace(qwenBaseModel(model))) + } + return parseGeminiResponse(body) +} + +func (p *VertexProvider) ChatStream(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}, onDelta func(string)) (*LLMResponse, error) { + body, status, ctype, err := p.doRequest(ctx, messages, tools, model, options, true, onDelta) + if err != nil { + return nil, err + } + if status != http.StatusOK { + return nil, fmt.Errorf("API error (status %d, content-type %q): %s", status, ctype, previewResponseBody(body)) + } + if !json.Valid(body) { + return nil, fmt.Errorf("API error (status %d, content-type %q): non-JSON response: %s", status, ctype, previewResponseBody(body)) + } + if isVertexImagenModel(model) { + body = convertVertexImagenToGeminiResponse(body, strings.TrimSpace(qwenBaseModel(model))) + } + return parseGeminiResponse(body) +} + +func (p *VertexProvider) CountTokens(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (*UsageInfo, error) { + if p == nil || p.base == nil { + return nil, fmt.Errorf("provider not configured") + } + attempts, err := p.vertexAttempts(ctx) + if err != nil { + return nil, err + } + var lastBody []byte + var lastStatus int + var lastType string + for _, attempt := range attempts { + requestBody := p.buildRequestBody(messages, nil, model, options) + if isVertexImagenModel(model) { + requestBody, err = convertVertexImagenRequest(requestBody) + if err != nil { + return nil, err + } + } + delete(requestBody, "tools") + delete(requestBody, "toolConfig") + delete(requestBody, "generationConfig") + endpoint := p.endpoint(attempt, model, "countTokens", false, options) + body, status, ctype, reqErr := p.performAttempt(ctx, endpoint, requestBody, attempt, false, nil) + if reqErr != nil { + return nil, reqErr + } + lastBody, lastStatus, lastType = body, status, ctype + reason, retry := classifyOAuthFailure(status, body) + if retry { + applyAttemptFailure(p.base, attempt, reason, nil) + continue + } + if status != http.StatusOK { + return nil, fmt.Errorf("API error (status %d, content-type %q): %s", status, ctype, previewResponseBody(body)) + } + var payload struct { + TotalTokens int `json:"totalTokens"` + } + if err := json.Unmarshal(body, &payload); err != nil { + return nil, fmt.Errorf("invalid countTokens response: %w", err) + } + p.base.markAttemptSuccess(attempt) + return &UsageInfo{PromptTokens: payload.TotalTokens, TotalTokens: payload.TotalTokens}, nil + } + return nil, fmt.Errorf("API error (status %d, content-type %q): %s", lastStatus, lastType, previewResponseBody(lastBody)) +} + +func (p *VertexProvider) doRequest(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}, stream bool, onDelta func(string)) ([]byte, int, string, error) { + if p == nil || p.base == nil { + return nil, 0, "", fmt.Errorf("provider not configured") + } + attempts, err := p.vertexAttempts(ctx) + if err != nil { + return nil, 0, "", err + } + var lastBody []byte + var lastStatus int + var lastType string + for _, attempt := range attempts { + requestBody := p.buildRequestBody(messages, tools, model, options) + if isVertexImagenModel(model) { + requestBody, err = convertVertexImagenRequest(requestBody) + if err != nil { + return nil, 0, "", err + } + } + endpoint := p.endpoint(attempt, model, "generateContent", stream, options) + body, status, ctype, reqErr := p.performAttempt(ctx, endpoint, requestBody, attempt, stream, onDelta) + if reqErr != nil { + return nil, 0, "", reqErr + } + lastBody, lastStatus, lastType = body, status, ctype + reason, retry := classifyOAuthFailure(status, body) + if retry { + applyAttemptFailure(p.base, attempt, reason, nil) + continue + } + p.base.markAttemptSuccess(attempt) + return body, status, ctype, nil + } + return lastBody, lastStatus, lastType, nil +} + +func (p *VertexProvider) performAttempt(ctx context.Context, endpoint string, payload map[string]any, attempt authAttempt, stream bool, onDelta func(string)) ([]byte, int, string, error) { + reqBody, err := json.Marshal(payload) + if err != nil { + return nil, 0, "", fmt.Errorf("failed to marshal request: %w", err) + } + req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(reqBody)) + if err != nil { + return nil, 0, "", fmt.Errorf("failed to create request: %w", err) + } + req.Header.Set("Content-Type", "application/json") + if stream { + req.Header.Set("Accept", "text/event-stream") + } else { + req.Header.Set("Accept", "application/json") + } + if err := p.applyVertexAttemptAuth(ctx, req, attempt); err != nil { + return nil, 0, "", err + } + client, err := p.base.httpClientForAttempt(attempt) + if err != nil { + return nil, 0, "", err + } + resp, err := client.Do(req) + if err != nil { + return nil, 0, "", fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + ctype := strings.TrimSpace(resp.Header.Get("Content-Type")) + if stream && strings.Contains(strings.ToLower(ctype), "text/event-stream") { + return consumeGeminiStream(resp, onDelta) + } + body, readErr := ioReadAll(resp) + if readErr != nil { + return nil, resp.StatusCode, ctype, readErr + } + return body, resp.StatusCode, ctype, nil +} + +func (p *VertexProvider) buildRequestBody(messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) map[string]any { + return (&GeminiProvider{base: p.base}).buildRequestBody(messages, tools, model, options, false) +} + +func (p *VertexProvider) endpoint(attempt authAttempt, model, action string, stream bool, options map[string]interface{}) string { + base := vertexBaseURLForAttempt(p.base, attempt, options) + baseModel := strings.TrimSpace(qwenBaseModel(model)) + vertexAction := vertexAction(baseModel, stream) + if projectID, location, ok := vertexProjectLocation(attempt, options); ok { + if stream { + if isVertexImagenModel(baseModel) { + return fmt.Sprintf("%s/%s/projects/%s/locations/%s/publishers/google/models/%s:%s", base, vertexAPIVersion, projectID, location, baseModel, vertexAction) + } + return fmt.Sprintf("%s/%s/projects/%s/locations/%s/publishers/google/models/%s:%s?alt=sse", base, vertexAPIVersion, projectID, location, baseModel, vertexAction) + } + return fmt.Sprintf("%s/%s/projects/%s/locations/%s/publishers/google/models/%s:%s", base, vertexAPIVersion, projectID, location, baseModel, vertexAction) + } + if stream { + if isVertexImagenModel(baseModel) { + return fmt.Sprintf("%s/%s/publishers/google/models/%s:%s", base, vertexAPIVersion, baseModel, vertexAction) + } + return fmt.Sprintf("%s/%s/publishers/google/models/%s:%s?alt=sse", base, vertexAPIVersion, baseModel, vertexAction) + } + return fmt.Sprintf("%s/%s/publishers/google/models/%s:%s", base, vertexAPIVersion, baseModel, vertexAction) +} + +func vertexBaseURLForAttempt(base *HTTPProvider, attempt authAttempt, options map[string]interface{}) string { + customBase := "" + if attempt.session != nil && attempt.session.Token != nil { + if raw := strings.TrimSpace(asString(attempt.session.Token["base_url"])); raw != "" { + customBase = normalizeVertexBaseURL(raw) + } + } + if customBase == "" && base != nil && strings.TrimSpace(base.apiBase) != "" && !strings.Contains(strings.ToLower(base.apiBase), "api.openai.com") { + customBase = normalizeVertexBaseURL(base.apiBase) + } + if customBase != "" && !strings.EqualFold(customBase, vertexDefaultBaseURL) { + return customBase + } + location := vertexLocationForAttempt(attempt, options) + if location != "" && !strings.EqualFold(location, "global") { + return fmt.Sprintf("https://%s-aiplatform.googleapis.com", location) + } + return vertexDefaultBaseURL +} + +func normalizeVertexBaseURL(raw string) string { + trimmed := strings.TrimSpace(raw) + if trimmed == "" { + return vertexDefaultBaseURL + } + if !strings.Contains(trimmed, "://") { + trimmed = "https://" + trimmed + } + return normalizeAPIBase(trimmed) +} + +func vertexProjectLocation(attempt authAttempt, options map[string]interface{}) (string, string, bool) { + projectID := "" + if value, ok := stringOption(options, "vertex_project_id"); ok { + projectID = strings.TrimSpace(value) + } + if attempt.session != nil { + projectID = firstNonEmpty(projectID, strings.TrimSpace(attempt.session.ProjectID), asString(attempt.session.Token["project_id"]), asString(attempt.session.Token["projectId"]), asString(attempt.session.Token["project"])) + if projectID == "" { + projectID = strings.TrimSpace(asString(mapFromAny(attempt.session.Token["service_account"])["project_id"])) + } + } + location := vertexLocationForAttempt(attempt, options) + if strings.TrimSpace(projectID) == "" { + return "", "", false + } + return projectID, location, true +} + +func vertexLocationForAttempt(attempt authAttempt, options map[string]interface{}) string { + location := "" + if value, ok := stringOption(options, "vertex_location"); ok { + location = strings.TrimSpace(value) + } + if attempt.session != nil { + location = firstNonEmpty(location, asString(attempt.session.Token["location"]), asString(mapFromAny(attempt.session.Token["service_account"])["location"])) + } + if strings.TrimSpace(location) == "" { + location = vertexDefaultRegion + } + return location +} + +func isVertexImagenModel(model string) bool { + return strings.Contains(strings.ToLower(strings.TrimSpace(model)), "imagen") +} + +func vertexAction(model string, stream bool) string { + if isVertexImagenModel(model) { + return "predict" + } + if stream { + return "streamGenerateContent" + } + return "generateContent" +} + +func convertVertexImagenRequest(payload map[string]any) (map[string]any, error) { + prompt := strings.TrimSpace(geminiPromptTextFromPayload(payload)) + if prompt == "" { + return nil, fmt.Errorf("imagen: no prompt found in request") + } + request := map[string]any{ + "instances": []map[string]any{{"prompt": prompt}}, + "parameters": map[string]any{ + "sampleCount": 1, + }, + } + if generationConfig := mapFromAny(payload["generationConfig"]); len(generationConfig) > 0 { + if aspectRatio := strings.TrimSpace(asString(generationConfig["aspectRatio"])); aspectRatio != "" { + request["parameters"].(map[string]any)["aspectRatio"] = aspectRatio + } + switch count := generationConfig["sampleCount"].(type) { + case int: + if count > 0 { + request["parameters"].(map[string]any)["sampleCount"] = count + } + case float64: + if int(count) > 0 { + request["parameters"].(map[string]any)["sampleCount"] = int(count) + } + } + if negativePrompt := strings.TrimSpace(asString(generationConfig["negativePrompt"])); negativePrompt != "" { + request["instances"].([]map[string]any)[0]["negativePrompt"] = negativePrompt + } + } + return request, nil +} + +func geminiPromptTextFromPayload(payload map[string]any) string { + contents, _ := payload["contents"].([]map[string]any) + for _, content := range contents { + parts, _ := content["parts"].([]map[string]any) + for _, part := range parts { + if text := strings.TrimSpace(asString(part["text"])); text != "" { + return text + } + } + } + return strings.TrimSpace(asString(payload["prompt"])) +} + +func convertVertexImagenToGeminiResponse(data []byte, model string) []byte { + var raw struct { + Predictions []struct { + BytesBase64Encoded string `json:"bytesBase64Encoded"` + MIMEType string `json:"mimeType"` + } `json:"predictions"` + } + if err := json.Unmarshal(data, &raw); err != nil || len(raw.Predictions) == 0 { + return data + } + parts := make([]map[string]any, 0, len(raw.Predictions)) + for _, prediction := range raw.Predictions { + if strings.TrimSpace(prediction.BytesBase64Encoded) == "" { + continue + } + parts = append(parts, map[string]any{ + "inlineData": map[string]any{ + "mimeType": firstNonEmpty(strings.TrimSpace(prediction.MIMEType), "image/png"), + "data": prediction.BytesBase64Encoded, + }, + }) + } + if len(parts) == 0 { + return data + } + converted := map[string]any{ + "candidates": []map[string]any{{ + "content": map[string]any{ + "parts": parts, + "role": "model", + }, + "finishReason": "STOP", + }}, + "responseId": fmt.Sprintf("imagen-%d", time.Now().UnixNano()), + "modelVersion": model, + "usageMetadata": map[string]any{ + "promptTokenCount": 0, + "candidatesTokenCount": 0, + "totalTokenCount": 0, + }, + } + out, err := json.Marshal(converted) + if err != nil { + return data + } + return out +} + +func ioReadAll(resp *http.Response) ([]byte, error) { + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response: %w", err) + } + return body, nil +} + +func (p *VertexProvider) vertexAttempts(ctx context.Context) ([]authAttempt, error) { + if p == nil || p.base == nil { + return nil, fmt.Errorf("provider not configured") + } + mode := strings.ToLower(strings.TrimSpace(p.base.authMode)) + if mode != "oauth" && mode != "hybrid" { + return p.base.authAttempts(ctx) + } + out := make([]authAttempt, 0, 4) + if mode == "hybrid" { + if apiAttempt, apiReady := p.base.apiKeyAttempt(); apiReady { + out = append(out, apiAttempt) + } + } + if p.base.oauth == nil { + if len(out) > 0 { + p.base.updateCandidateOrder(out) + return out, nil + } + return nil, fmt.Errorf("oauth is enabled but provider session manager is not configured") + } + manager := p.base.oauth + manager.mu.Lock() + sessions, err := manager.loadAllLocked() + manager.mu.Unlock() + if err != nil { + return nil, err + } + for _, session := range sessions { + if session == nil || manager.sessionOnCooldown(session) { + continue + } + token := strings.TrimSpace(session.AccessToken) + if token == "" && !vertexHasServiceAccount(session) { + continue + } + out = append(out, authAttempt{session: session, token: token, kind: "oauth"}) + } + if len(out) == 0 { + return nil, fmt.Errorf("oauth session not found, run `clawgo provider login` first") + } + p.base.updateCandidateOrder(out) + return out, nil +} + +func (p *VertexProvider) applyVertexAttemptAuth(ctx context.Context, req *http.Request, attempt authAttempt) error { + if req == nil { + return nil + } + if attempt.kind == "api_key" { + applyGeminiAttemptAuth(req, attempt) + return nil + } + if strings.TrimSpace(attempt.token) != "" { + applyGeminiAttemptAuth(req, attempt) + return nil + } + saJSON, err := vertexServiceAccountJSON(attempt.session) + if err != nil { + return err + } + client, err := p.base.httpClientForAttempt(attempt) + if err == nil && client != nil { + ctx = context.WithValue(ctx, oauth2.HTTPClient, client) + } + var serviceAccount struct { + ClientEmail string `json:"client_email"` + PrivateKey string `json:"private_key"` + TokenURI string `json:"token_uri"` + } + if err := json.Unmarshal(saJSON, &serviceAccount); err != nil { + return fmt.Errorf("vertex service account parse failed: %w", err) + } + cfg := &jwt.Config{ + Email: strings.TrimSpace(serviceAccount.ClientEmail), + PrivateKey: []byte(serviceAccount.PrivateKey), + TokenURL: strings.TrimSpace(serviceAccount.TokenURI), + Scopes: []string{"https://www.googleapis.com/auth/cloud-platform"}, + } + tok, err := cfg.TokenSource(ctx).Token() + if err != nil { + return fmt.Errorf("vertex service account token failed: %w", err) + } + req.Header.Del("x-goog-api-key") + req.Header.Set("Authorization", "Bearer "+strings.TrimSpace(tok.AccessToken)) + return nil +} + +func vertexHasServiceAccount(session *oauthSession) bool { + if session == nil || session.Token == nil { + return false + } + return len(mapFromAny(session.Token["service_account"])) > 0 +} + +func vertexServiceAccountJSON(session *oauthSession) ([]byte, error) { + if !vertexHasServiceAccount(session) { + return nil, fmt.Errorf("vertex service account missing") + } + raw := mapFromAny(session.Token["service_account"]) + if projectID := firstNonEmpty(asString(raw["project_id"]), strings.TrimSpace(session.ProjectID), asString(session.Token["project_id"]), asString(session.Token["project"])); projectID != "" { + raw["project_id"] = projectID + } + data, err := json.Marshal(raw) + if err != nil { + return nil, fmt.Errorf("vertex service account marshal failed: %w", err) + } + return data, nil +} diff --git a/pkg/providers/vertex_provider_test.go b/pkg/providers/vertex_provider_test.go new file mode 100644 index 0000000..4afb52e --- /dev/null +++ b/pkg/providers/vertex_provider_test.go @@ -0,0 +1,243 @@ +package providers + +import ( + "crypto/rand" + "crypto/rsa" + "crypto/x509" + "encoding/pem" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/YspCoder/clawgo/pkg/config" +) + +func TestCreateProviderByNameRoutesVertexProvider(t *testing.T) { + cfg := &config.Config{ + Models: config.ModelsConfig{ + Providers: map[string]config.ProviderConfig{ + "vertex": { + TimeoutSec: 30, + Models: []string{"gemini-2.5-pro"}, + }, + }, + }, + } + provider, err := CreateProviderByName(cfg, "vertex") + if err != nil { + t.Fatalf("CreateProviderByName error: %v", err) + } + if _, ok := provider.(*VertexProvider); !ok { + t.Fatalf("provider = %T, want *VertexProvider", provider) + } +} + +func TestVertexEndpointWithProjectLocation(t *testing.T) { + p := NewVertexProvider("vertex", "", "", "gemini-2.5-pro", false, "oauth", 5*time.Second, nil) + endpoint := p.endpoint(authAttempt{ + kind: "oauth", + session: &oauthSession{ + ProjectID: "demo-project", + Token: map[string]any{ + "location": "asia-east1", + }, + }, + }, "gemini-2.5-pro", "generateContent", false, nil) + want := "https://asia-east1-aiplatform.googleapis.com/v1/projects/demo-project/locations/asia-east1/publishers/google/models/gemini-2.5-pro:generateContent" + if endpoint != want { + t.Fatalf("endpoint = %q, want %q", endpoint, want) + } +} + +func TestVertexEndpointUsesPredictForImagen(t *testing.T) { + p := NewVertexProvider("vertex", "", "", "imagen-4.0-generate-001", false, "oauth", 5*time.Second, nil) + endpoint := p.endpoint(authAttempt{ + kind: "oauth", + session: &oauthSession{ + ProjectID: "demo-project", + Token: map[string]any{ + "location": "us-central1", + }, + }, + }, "imagen-4.0-generate-001", "generateContent", false, nil) + want := "https://us-central1-aiplatform.googleapis.com/v1/projects/demo-project/locations/us-central1/publishers/google/models/imagen-4.0-generate-001:predict" + if endpoint != want { + t.Fatalf("endpoint = %q, want %q", endpoint, want) + } +} + +func TestVertexEndpointUsesGlobalBaseForGlobalLocation(t *testing.T) { + p := NewVertexProvider("vertex", "", "", "gemini-2.5-pro", false, "oauth", 5*time.Second, nil) + endpoint := p.endpoint(authAttempt{ + kind: "oauth", + session: &oauthSession{ + ProjectID: "demo-project", + Token: map[string]any{ + "location": "global", + }, + }, + }, "gemini-2.5-pro", "generateContent", false, nil) + want := "https://aiplatform.googleapis.com/v1/projects/demo-project/locations/global/publishers/google/models/gemini-2.5-pro:generateContent" + if endpoint != want { + t.Fatalf("endpoint = %q, want %q", endpoint, want) + } +} + +func TestConvertVertexImagenRequest(t *testing.T) { + payload := map[string]any{ + "contents": []map[string]any{ + {"parts": []map[string]any{{"text": "draw a cat"}}}, + }, + "generationConfig": map[string]any{ + "aspectRatio": "1:1", + "sampleCount": 2.0, + "negativePrompt": "blurry", + }, + } + req, err := convertVertexImagenRequest(payload) + if err != nil { + t.Fatalf("convertVertexImagenRequest error: %v", err) + } + instances, _ := req["instances"].([]map[string]any) + if len(instances) != 1 || instances[0]["prompt"] != "draw a cat" { + t.Fatalf("unexpected instances: %+v", instances) + } + params, _ := req["parameters"].(map[string]any) + if params["aspectRatio"] != "1:1" || params["sampleCount"] != 2 { + t.Fatalf("unexpected parameters: %+v", params) + } +} + +func TestConvertVertexImagenToGeminiResponse(t *testing.T) { + body := []byte(`{"predictions":[{"bytesBase64Encoded":"abcd","mimeType":"image/png"}]}`) + converted := convertVertexImagenToGeminiResponse(body, "imagen-4.0-generate-001") + resp, err := parseGeminiResponse(converted) + if err != nil { + t.Fatalf("parseGeminiResponse error: %v", err) + } + if resp.FinishReason != "STOP" { + t.Fatalf("finish reason = %q", resp.FinishReason) + } +} + +func TestVertexProjectLocationFallsBackToProjectAlias(t *testing.T) { + projectID, location, ok := vertexProjectLocation(authAttempt{ + kind: "oauth", + session: &oauthSession{ + Token: map[string]any{ + "project": "demo-project", + "location": "asia-east1", + }, + }, + }, nil) + if !ok || projectID != "demo-project" || location != "asia-east1" { + t.Fatalf("unexpected project/location: ok=%v project=%q location=%q", ok, projectID, location) + } +} + +func TestVertexServiceAccountJSONBackfillsProjectID(t *testing.T) { + session := &oauthSession{ + ProjectID: "demo-project", + Token: map[string]any{ + "service_account": map[string]any{ + "type": "service_account", + "client_email": "svc@example.com", + "private_key": "key", + "token_uri": "https://example.com/token", + }, + }, + } + data, err := vertexServiceAccountJSON(session) + if err != nil { + t.Fatalf("vertexServiceAccountJSON error: %v", err) + } + if !strings.Contains(string(data), `"project_id":"demo-project"`) { + t.Fatalf("expected project_id in service account json, got %s", string(data)) + } +} + +func TestVertexProviderChatWithAPIKey(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/v1/publishers/google/models/gemini-2.5-pro:generateContent" { + http.NotFound(w, r) + return + } + if got := r.Header.Get("x-goog-api-key"); got != "token" { + t.Fatalf("x-goog-api-key = %q", got) + } + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"candidates":[{"content":{"parts":[{"text":"ok"}]},"finishReason":"STOP"}]}`)) + })) + defer server.Close() + + p := NewVertexProvider("vertex", "token", server.URL, "gemini-2.5-pro", false, "api_key", 5*time.Second, nil) + resp, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hello"}}, nil, "gemini-2.5-pro", nil) + if err != nil { + t.Fatalf("Chat error: %v", err) + } + if resp.Content != "ok" { + t.Fatalf("content = %q, want ok", resp.Content) + } +} + +func TestVertexProviderChatWithServiceAccount(t *testing.T) { + tokenServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"access_token":"sa-token","token_type":"Bearer","expires_in":3600}`)) + })) + defer tokenServer.Close() + + apiServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/v1/projects/demo-project/locations/us-central1/publishers/google/models/gemini-2.5-pro:generateContent" { + http.NotFound(w, r) + return + } + if got := r.Header.Get("Authorization"); got != "Bearer sa-token" { + t.Fatalf("Authorization = %q", got) + } + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"candidates":[{"content":{"parts":[{"text":"ok"}]},"finishReason":"STOP"}]}`)) + })) + defer apiServer.Close() + + privateKey, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + t.Fatalf("rsa.GenerateKey: %v", err) + } + pemKey := pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(privateKey)}) + + p := NewVertexProvider("vertex", "", apiServer.URL, "gemini-2.5-pro", false, "oauth", 5*time.Second, nil) + p.base.oauth = &oauthManager{} + p.base.authMode = "oauth" + p.base.oauth = &oauthManager{ + cfg: oauthConfig{}, + cached: []*oauthSession{{ + ProjectID: "demo-project", + Token: map[string]any{ + "location": "us-central1", + "service_account": map[string]any{ + "type": "service_account", + "project_id": "demo-project", + "private_key_id": "key-1", + "private_key": string(pemKey), + "client_email": "svc@example.iam.gserviceaccount.com", + "client_id": "1234567890", + "token_uri": tokenServer.URL, + "auth_uri": "https://accounts.google.com/o/oauth2/auth", + "auth_provider_x509_cert_url": "https://www.googleapis.com/oauth2/v1/certs", + "client_x509_cert_url": "https://example.com/cert", + }, + }, + }}, + } + + resp, err := p.Chat(t.Context(), []Message{{Role: "user", Content: "hello"}}, nil, "gemini-2.5-pro", nil) + if err != nil { + t.Fatalf("Chat error: %v", err) + } + if resp.Content != "ok" { + t.Fatalf("content = %q, want ok", resp.Content) + } +} diff --git a/pkg/wsrelay/http.go b/pkg/wsrelay/http.go new file mode 100644 index 0000000..e2205c0 --- /dev/null +++ b/pkg/wsrelay/http.go @@ -0,0 +1,241 @@ +package wsrelay + +import ( + "bytes" + "context" + "errors" + "fmt" + "net/http" + "time" + + "github.com/google/uuid" +) + +type HTTPRequest struct { + Method string + URL string + Headers http.Header + Body []byte +} + +type HTTPResponse struct { + Status int + Headers http.Header + Body []byte +} + +type StreamEvent struct { + Type string + Payload []byte + Status int + Headers http.Header + Err error +} + +func (m *Manager) NonStream(ctx context.Context, provider string, req *HTTPRequest) (*HTTPResponse, error) { + if req == nil { + return nil, fmt.Errorf("wsrelay: request is nil") + } + msg := Message{ID: uuid.NewString(), Type: MessageTypeHTTPReq, Payload: encodeRequest(req)} + respCh, err := m.Send(ctx, provider, msg) + if err != nil { + return nil, err + } + var ( + streamMode bool + streamResp *HTTPResponse + streamBody bytes.Buffer + ) + for { + select { + case <-ctx.Done(): + return nil, ctx.Err() + case msg, ok := <-respCh: + if !ok { + if streamMode { + if streamResp == nil { + streamResp = &HTTPResponse{Status: http.StatusOK, Headers: make(http.Header)} + } else if streamResp.Headers == nil { + streamResp.Headers = make(http.Header) + } + streamResp.Body = append(streamResp.Body[:0], streamBody.Bytes()...) + return streamResp, nil + } + return nil, errors.New("wsrelay: connection closed during response") + } + switch msg.Type { + case MessageTypeHTTPResp: + resp := decodeResponse(msg.Payload) + if streamMode && streamBody.Len() > 0 && len(resp.Body) == 0 { + resp.Body = append(resp.Body[:0], streamBody.Bytes()...) + } + return resp, nil + case MessageTypeError: + return nil, decodeError(msg.Payload) + case MessageTypeStreamStart, MessageTypeStreamChunk: + if msg.Type == MessageTypeStreamStart { + streamMode = true + streamResp = decodeResponse(msg.Payload) + if streamResp.Headers == nil { + streamResp.Headers = make(http.Header) + } + streamBody.Reset() + continue + } + if !streamMode { + streamMode = true + streamResp = &HTTPResponse{Status: http.StatusOK, Headers: make(http.Header)} + } + chunk := decodeChunk(msg.Payload) + if len(chunk) > 0 { + streamBody.Write(chunk) + } + case MessageTypeStreamEnd: + if !streamMode { + return &HTTPResponse{Status: http.StatusOK, Headers: make(http.Header)}, nil + } + if streamResp == nil { + streamResp = &HTTPResponse{Status: http.StatusOK, Headers: make(http.Header)} + } else if streamResp.Headers == nil { + streamResp.Headers = make(http.Header) + } + streamResp.Body = append(streamResp.Body[:0], streamBody.Bytes()...) + return streamResp, nil + } + } + } +} + +func (m *Manager) Stream(ctx context.Context, provider string, req *HTTPRequest) (<-chan StreamEvent, error) { + if req == nil { + return nil, fmt.Errorf("wsrelay: request is nil") + } + msg := Message{ID: uuid.NewString(), Type: MessageTypeHTTPReq, Payload: encodeRequest(req)} + respCh, err := m.Send(ctx, provider, msg) + if err != nil { + return nil, err + } + out := make(chan StreamEvent) + go func() { + defer close(out) + send := func(ev StreamEvent) bool { + if ctx == nil { + out <- ev + return true + } + select { + case <-ctx.Done(): + return false + case out <- ev: + return true + } + } + for { + select { + case <-ctx.Done(): + return + case msg, ok := <-respCh: + if !ok { + _ = send(StreamEvent{Err: errors.New("wsrelay: stream closed")}) + return + } + switch msg.Type { + case MessageTypeStreamStart: + resp := decodeResponse(msg.Payload) + if ok := send(StreamEvent{Type: MessageTypeStreamStart, Status: resp.Status, Headers: resp.Headers}); !ok { + return + } + case MessageTypeStreamChunk: + chunk := decodeChunk(msg.Payload) + if ok := send(StreamEvent{Type: MessageTypeStreamChunk, Payload: chunk}); !ok { + return + } + case MessageTypeStreamEnd: + _ = send(StreamEvent{Type: MessageTypeStreamEnd}) + return + case MessageTypeError: + _ = send(StreamEvent{Type: MessageTypeError, Err: decodeError(msg.Payload)}) + return + case MessageTypeHTTPResp: + resp := decodeResponse(msg.Payload) + _ = send(StreamEvent{Type: MessageTypeHTTPResp, Status: resp.Status, Headers: resp.Headers, Payload: resp.Body}) + return + } + } + } + }() + return out, nil +} + +func encodeRequest(req *HTTPRequest) map[string]any { + headers := make(map[string]any, len(req.Headers)) + for key, values := range req.Headers { + copyValues := make([]string, len(values)) + copy(copyValues, values) + headers[key] = copyValues + } + return map[string]any{ + "method": req.Method, + "url": req.URL, + "headers": headers, + "body": string(req.Body), + "sent_at": time.Now().UTC().Format(time.RFC3339Nano), + } +} + +func decodeResponse(payload map[string]any) *HTTPResponse { + if payload == nil { + return &HTTPResponse{Status: http.StatusBadGateway, Headers: make(http.Header)} + } + resp := &HTTPResponse{Status: http.StatusOK, Headers: make(http.Header)} + if status, ok := payload["status"].(float64); ok { + resp.Status = int(status) + } + if headers, ok := payload["headers"].(map[string]any); ok { + for key, raw := range headers { + switch v := raw.(type) { + case []any: + for _, item := range v { + if str, ok := item.(string); ok { + resp.Headers.Add(key, str) + } + } + case []string: + for _, str := range v { + resp.Headers.Add(key, str) + } + case string: + resp.Headers.Set(key, v) + } + } + } + if body, ok := payload["body"].(string); ok { + resp.Body = []byte(body) + } + return resp +} + +func decodeChunk(payload map[string]any) []byte { + if payload == nil { + return nil + } + if data, ok := payload["data"].(string); ok { + return []byte(data) + } + return nil +} + +func decodeError(payload map[string]any) error { + if payload == nil { + return errors.New("wsrelay: unknown error") + } + message, _ := payload["error"].(string) + status := 0 + if v, ok := payload["status"].(float64); ok { + status = int(v) + } + if message == "" { + message = "wsrelay: upstream error" + } + return fmt.Errorf("%s (status=%d)", message, status) +} diff --git a/pkg/wsrelay/manager.go b/pkg/wsrelay/manager.go new file mode 100644 index 0000000..4d31624 --- /dev/null +++ b/pkg/wsrelay/manager.go @@ -0,0 +1,192 @@ +package wsrelay + +import ( + "context" + "crypto/rand" + "errors" + "fmt" + "net/http" + "strings" + "sync" + "time" + + "github.com/gorilla/websocket" +) + +type Manager struct { + path string + upgrader websocket.Upgrader + sessions map[string]*session + sessMutex sync.RWMutex + + providerFactory func(*http.Request) (string, error) + onConnected func(string) + onDisconnected func(string, error) + + logDebugf func(string, ...any) + logInfof func(string, ...any) + logWarnf func(string, ...any) +} + +type Options struct { + Path string + ProviderFactory func(*http.Request) (string, error) + OnConnected func(string) + OnDisconnected func(string, error) + LogDebugf func(string, ...any) + LogInfof func(string, ...any) + LogWarnf func(string, ...any) +} + +func NewManager(opts Options) *Manager { + path := strings.TrimSpace(opts.Path) + if path == "" { + path = "/v1/ws" + } + if !strings.HasPrefix(path, "/") { + path = "/" + path + } + mgr := &Manager{ + path: path, + sessions: make(map[string]*session), + upgrader: websocket.Upgrader{ + ReadBufferSize: 1024, + WriteBufferSize: 1024, + CheckOrigin: func(*http.Request) bool { return true }, + }, + providerFactory: opts.ProviderFactory, + onConnected: opts.OnConnected, + onDisconnected: opts.OnDisconnected, + logDebugf: opts.LogDebugf, + logInfof: opts.LogInfof, + logWarnf: opts.LogWarnf, + } + if mgr.logDebugf == nil { + mgr.logDebugf = func(string, ...any) {} + } + if mgr.logInfof == nil { + mgr.logInfof = func(string, ...any) {} + } + if mgr.logWarnf == nil { + mgr.logWarnf = func(string, ...any) {} + } + return mgr +} + +func (m *Manager) Path() string { + if m == nil { + return "/v1/ws" + } + return m.path +} + +func (m *Manager) Handler() http.Handler { + return http.HandlerFunc(m.handleWebsocket) +} + +func (m *Manager) Stop(_ context.Context) error { + m.sessMutex.Lock() + sessions := make([]*session, 0, len(m.sessions)) + for _, sess := range m.sessions { + sessions = append(sessions, sess) + } + m.sessions = make(map[string]*session) + m.sessMutex.Unlock() + + for _, sess := range sessions { + if sess != nil { + sess.cleanup(errors.New("wsrelay: manager stopped")) + } + } + return nil +} + +func (m *Manager) handleWebsocket(w http.ResponseWriter, r *http.Request) { + expectedPath := m.Path() + if expectedPath != "" && r.URL != nil && r.URL.Path != expectedPath { + http.NotFound(w, r) + return + } + if !strings.EqualFold(r.Method, http.MethodGet) { + w.Header().Set("Allow", http.MethodGet) + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + return + } + conn, err := m.upgrader.Upgrade(w, r, nil) + if err != nil { + m.logWarnf("wsrelay: upgrade failed: %v", err) + return + } + s := newSession(conn, m, randomProviderName()) + if m.providerFactory != nil { + name, err := m.providerFactory(r) + if err != nil { + s.cleanup(err) + return + } + if strings.TrimSpace(name) != "" { + s.provider = strings.ToLower(strings.TrimSpace(name)) + } + } + if s.provider == "" { + s.provider = strings.ToLower(s.id) + } + m.sessMutex.Lock() + var replaced *session + if existing, ok := m.sessions[s.provider]; ok { + replaced = existing + } + m.sessions[s.provider] = s + m.sessMutex.Unlock() + + if replaced != nil { + replaced.cleanup(errors.New("replaced by new connection")) + } + if m.onConnected != nil { + m.onConnected(s.provider) + } + go s.run(context.Background()) +} + +func (m *Manager) Send(ctx context.Context, provider string, msg Message) (<-chan Message, error) { + s := m.session(provider) + if s == nil { + return nil, fmt.Errorf("wsrelay: provider %s not connected", provider) + } + return s.request(ctx, msg) +} + +func (m *Manager) session(provider string) *session { + key := strings.ToLower(strings.TrimSpace(provider)) + m.sessMutex.RLock() + s := m.sessions[key] + m.sessMutex.RUnlock() + return s +} + +func (m *Manager) handleSessionClosed(s *session, cause error) { + if s == nil { + return + } + key := strings.ToLower(strings.TrimSpace(s.provider)) + m.sessMutex.Lock() + if cur, ok := m.sessions[key]; ok && cur == s { + delete(m.sessions, key) + } + m.sessMutex.Unlock() + if m.onDisconnected != nil { + m.onDisconnected(s.provider, cause) + } +} + +func randomProviderName() string { + const alphabet = "abcdefghijklmnopqrstuvwxyz0123456789" + buf := make([]byte, 16) + if _, err := rand.Read(buf); err != nil { + return fmt.Sprintf("aistudio-%x", time.Now().UnixNano()) + } + for i := range buf { + buf[i] = alphabet[int(buf[i])%len(alphabet)] + } + return "aistudio-" + string(buf) +} diff --git a/pkg/wsrelay/message.go b/pkg/wsrelay/message.go new file mode 100644 index 0000000..5d667c7 --- /dev/null +++ b/pkg/wsrelay/message.go @@ -0,0 +1,19 @@ +package wsrelay + +// Message represents the JSON payload exchanged with websocket clients. +type Message struct { + ID string `json:"id"` + Type string `json:"type"` + Payload map[string]any `json:"payload,omitempty"` +} + +const ( + MessageTypeHTTPReq = "http_request" + MessageTypeHTTPResp = "http_response" + MessageTypeStreamStart = "stream_start" + MessageTypeStreamChunk = "stream_chunk" + MessageTypeStreamEnd = "stream_end" + MessageTypeError = "error" + MessageTypePing = "ping" + MessageTypePong = "pong" +) diff --git a/pkg/wsrelay/session.go b/pkg/wsrelay/session.go new file mode 100644 index 0000000..b33b022 --- /dev/null +++ b/pkg/wsrelay/session.go @@ -0,0 +1,183 @@ +package wsrelay + +import ( + "context" + "errors" + "fmt" + "sync" + "time" + + "github.com/gorilla/websocket" +) + +const ( + readTimeout = 60 * time.Second + writeTimeout = 10 * time.Second + maxInboundMessageLen = 64 << 20 + heartbeatInterval = 30 * time.Second +) + +var errClosed = errors.New("websocket session closed") + +type pendingRequest struct { + ch chan Message + closeOnce sync.Once +} + +func (pr *pendingRequest) close() { + if pr == nil { + return + } + pr.closeOnce.Do(func() { close(pr.ch) }) +} + +type session struct { + conn *websocket.Conn + manager *Manager + provider string + id string + closed chan struct{} + closeOnce sync.Once + writeMutex sync.Mutex + pending sync.Map +} + +func newSession(conn *websocket.Conn, mgr *Manager, id string) *session { + s := &session{ + conn: conn, + manager: mgr, + id: id, + closed: make(chan struct{}), + } + conn.SetReadLimit(maxInboundMessageLen) + conn.SetReadDeadline(time.Now().Add(readTimeout)) + conn.SetPongHandler(func(string) error { + return conn.SetReadDeadline(time.Now().Add(readTimeout)) + }) + s.startHeartbeat() + return s +} + +func (s *session) startHeartbeat() { + if s == nil || s.conn == nil { + return + } + ticker := time.NewTicker(heartbeatInterval) + go func() { + defer ticker.Stop() + for { + select { + case <-s.closed: + return + case <-ticker.C: + s.writeMutex.Lock() + err := s.conn.WriteControl(websocket.PingMessage, []byte("ping"), time.Now().Add(writeTimeout)) + s.writeMutex.Unlock() + if err != nil { + s.cleanup(err) + return + } + } + } + }() +} + +func (s *session) run(_ context.Context) { + defer s.cleanup(errClosed) + for { + var msg Message + if err := s.conn.ReadJSON(&msg); err != nil { + s.cleanup(err) + return + } + s.dispatch(msg) + } +} + +func (s *session) dispatch(msg Message) { + if msg.Type == MessageTypePing { + _ = s.send(context.Background(), Message{ID: msg.ID, Type: MessageTypePong}) + return + } + if value, ok := s.pending.Load(msg.ID); ok { + req := value.(*pendingRequest) + select { + case req.ch <- msg: + default: + } + if msg.Type == MessageTypeHTTPResp || msg.Type == MessageTypeError || msg.Type == MessageTypeStreamEnd { + if actual, loaded := s.pending.LoadAndDelete(msg.ID); loaded { + actual.(*pendingRequest).close() + } + } + return + } + if msg.Type == MessageTypeHTTPResp || msg.Type == MessageTypeError || msg.Type == MessageTypeStreamEnd { + s.manager.logDebugf("wsrelay: received terminal message for unknown id %s (provider=%s)", msg.ID, s.provider) + } +} + +func (s *session) send(_ context.Context, msg Message) error { + select { + case <-s.closed: + return errClosed + default: + } + s.writeMutex.Lock() + defer s.writeMutex.Unlock() + if err := s.conn.SetWriteDeadline(time.Now().Add(writeTimeout)); err != nil { + return fmt.Errorf("set write deadline: %w", err) + } + if err := s.conn.WriteJSON(msg); err != nil { + return fmt.Errorf("write json: %w", err) + } + return nil +} + +func (s *session) request(ctx context.Context, msg Message) (<-chan Message, error) { + if msg.ID == "" { + return nil, fmt.Errorf("wsrelay: message id is required") + } + if _, loaded := s.pending.LoadOrStore(msg.ID, &pendingRequest{ch: make(chan Message, 8)}); loaded { + return nil, fmt.Errorf("wsrelay: duplicate message id %s", msg.ID) + } + value, _ := s.pending.Load(msg.ID) + req := value.(*pendingRequest) + if err := s.send(ctx, msg); err != nil { + if actual, loaded := s.pending.LoadAndDelete(msg.ID); loaded { + actual.(*pendingRequest).close() + } + return nil, err + } + go func() { + select { + case <-ctx.Done(): + if actual, loaded := s.pending.LoadAndDelete(msg.ID); loaded { + actual.(*pendingRequest).close() + } + case <-s.closed: + } + }() + return req.ch, nil +} + +func (s *session) cleanup(cause error) { + s.closeOnce.Do(func() { + close(s.closed) + s.pending.Range(func(key, value any) bool { + req := value.(*pendingRequest) + msg := Message{ID: key.(string), Type: MessageTypeError, Payload: map[string]any{"error": cause.Error()}} + select { + case req.ch <- msg: + default: + } + req.close() + return true + }) + s.pending = sync.Map{} + _ = s.conn.Close() + if s.manager != nil { + s.manager.handleSessionClosed(s, cause) + } + }) +}