Files
clawgo/pkg/api/server_providers.go
2026-03-15 15:31:00 +08:00

424 lines
13 KiB
Go

package api
import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"strconv"
"strings"
"time"
cfgpkg "github.com/YspCoder/clawgo/pkg/config"
"github.com/YspCoder/clawgo/pkg/providers"
)
func (s *Server) handleWebUIProviderOAuthStart(w http.ResponseWriter, r *http.Request) {
if !s.checkAuth(r) {
http.Error(w, "unauthorized", http.StatusUnauthorized)
return
}
if r.Method != http.MethodPost {
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
return
}
var body struct {
Provider string `json:"provider"`
AccountLabel string `json:"account_label"`
NetworkProxy string `json:"network_proxy"`
ProviderConfig cfgpkg.ProviderConfig `json:"provider_config"`
}
if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
http.Error(w, "invalid json", http.StatusBadRequest)
return
}
cfg, pc, err := s.resolveProviderConfig(strings.TrimSpace(body.Provider), body.ProviderConfig)
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
_ = cfg
timeout := pc.TimeoutSec
if timeout <= 0 {
timeout = 90
}
loginMgr, err := providers.NewOAuthLoginManager(pc, time.Duration(timeout)*time.Second)
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
flow, err := loginMgr.StartManualFlowWithOptions(providers.OAuthLoginOptions{
AccountLabel: body.AccountLabel,
NetworkProxy: firstNonEmptyString(strings.TrimSpace(body.NetworkProxy), strings.TrimSpace(pc.OAuth.NetworkProxy)),
})
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
flowID := fmt.Sprintf("%d", time.Now().UnixNano())
s.oauthFlowMu.Lock()
s.oauthFlows[flowID] = flow
s.oauthFlowMu.Unlock()
writeJSON(w, map[string]interface{}{
"ok": true,
"flow_id": flowID,
"mode": flow.Mode,
"auth_url": flow.AuthURL,
"user_code": flow.UserCode,
"instructions": flow.Instructions,
"account_label": strings.TrimSpace(body.AccountLabel),
"network_proxy": strings.TrimSpace(body.NetworkProxy),
})
}
func (s *Server) handleWebUIProviderOAuthComplete(w http.ResponseWriter, r *http.Request) {
if !s.checkAuth(r) {
http.Error(w, "unauthorized", http.StatusUnauthorized)
return
}
if r.Method != http.MethodPost {
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
return
}
var body struct {
Provider string `json:"provider"`
FlowID string `json:"flow_id"`
CallbackURL string `json:"callback_url"`
AccountLabel string `json:"account_label"`
NetworkProxy string `json:"network_proxy"`
ProviderConfig cfgpkg.ProviderConfig `json:"provider_config"`
}
if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
http.Error(w, "invalid json", http.StatusBadRequest)
return
}
cfg, pc, err := s.resolveProviderConfig(strings.TrimSpace(body.Provider), body.ProviderConfig)
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
timeout := pc.TimeoutSec
if timeout <= 0 {
timeout = 90
}
loginMgr, err := providers.NewOAuthLoginManager(pc, time.Duration(timeout)*time.Second)
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
s.oauthFlowMu.Lock()
flow := s.oauthFlows[strings.TrimSpace(body.FlowID)]
delete(s.oauthFlows, strings.TrimSpace(body.FlowID))
s.oauthFlowMu.Unlock()
if flow == nil {
http.Error(w, "oauth flow not found", http.StatusBadRequest)
return
}
session, models, err := loginMgr.CompleteManualFlowWithOptions(r.Context(), pc.APIBase, flow, body.CallbackURL, providers.OAuthLoginOptions{
AccountLabel: strings.TrimSpace(body.AccountLabel),
NetworkProxy: firstNonEmptyString(strings.TrimSpace(body.NetworkProxy), strings.TrimSpace(pc.OAuth.NetworkProxy)),
})
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
if session.CredentialFile != "" {
pc.OAuth.CredentialFile = session.CredentialFile
pc.OAuth.CredentialFiles = appendUniqueStrings(pc.OAuth.CredentialFiles, session.CredentialFile)
}
if err := s.saveProviderConfig(cfg, body.Provider, pc); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
writeJSON(w, map[string]interface{}{
"ok": true,
"account": session.Email,
"credential_file": session.CredentialFile,
"network_proxy": session.NetworkProxy,
"models": models,
})
}
func (s *Server) handleWebUIProviderOAuthImport(w http.ResponseWriter, r *http.Request) {
if !s.checkAuth(r) {
http.Error(w, "unauthorized", http.StatusUnauthorized)
return
}
if r.Method != http.MethodPost {
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
return
}
if err := r.ParseMultipartForm(16 << 20); err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
providerName := strings.TrimSpace(r.FormValue("provider"))
accountLabel := strings.TrimSpace(r.FormValue("account_label"))
networkProxy := strings.TrimSpace(r.FormValue("network_proxy"))
inlineCfgRaw := strings.TrimSpace(r.FormValue("provider_config"))
var inlineCfg cfgpkg.ProviderConfig
if inlineCfgRaw != "" {
if err := json.Unmarshal([]byte(inlineCfgRaw), &inlineCfg); err != nil {
http.Error(w, "invalid provider_config", http.StatusBadRequest)
return
}
}
cfg, pc, err := s.resolveProviderConfig(providerName, inlineCfg)
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
file, header, err := r.FormFile("file")
if err != nil {
http.Error(w, "file required", http.StatusBadRequest)
return
}
defer file.Close()
data, err := io.ReadAll(file)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
timeout := pc.TimeoutSec
if timeout <= 0 {
timeout = 90
}
loginMgr, err := providers.NewOAuthLoginManager(pc, time.Duration(timeout)*time.Second)
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
session, models, err := loginMgr.ImportAuthJSONWithOptions(r.Context(), pc.APIBase, header.Filename, data, providers.OAuthLoginOptions{
AccountLabel: accountLabel,
NetworkProxy: firstNonEmptyString(networkProxy, strings.TrimSpace(pc.OAuth.NetworkProxy)),
})
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
if session.CredentialFile != "" {
pc.OAuth.CredentialFile = session.CredentialFile
pc.OAuth.CredentialFiles = appendUniqueStrings(pc.OAuth.CredentialFiles, session.CredentialFile)
}
if err := s.saveProviderConfig(cfg, providerName, pc); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
writeJSON(w, map[string]interface{}{
"ok": true,
"account": session.Email,
"credential_file": session.CredentialFile,
"network_proxy": session.NetworkProxy,
"models": models,
})
}
func (s *Server) handleWebUIProviderOAuthAccounts(w http.ResponseWriter, r *http.Request) {
if !s.checkAuth(r) {
http.Error(w, "unauthorized", http.StatusUnauthorized)
return
}
providerName := strings.TrimSpace(r.URL.Query().Get("provider"))
cfg, pc, err := s.loadProviderConfig(providerName)
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
_ = cfg
timeout := pc.TimeoutSec
if timeout <= 0 {
timeout = 90
}
loginMgr, err := providers.NewOAuthLoginManager(pc, time.Duration(timeout)*time.Second)
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
switch r.Method {
case http.MethodGet:
accounts, err := loginMgr.ListAccounts()
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
writeJSON(w, map[string]interface{}{"ok": true, "accounts": accounts})
case http.MethodPost:
var body struct {
Action string `json:"action"`
CredentialFile string `json:"credential_file"`
}
if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
http.Error(w, "invalid json", http.StatusBadRequest)
return
}
action := strings.ToLower(strings.TrimSpace(body.Action))
handler := providerOAuthAccountActionHandlers[action]
if handler == nil {
http.Error(w, "unsupported action", http.StatusBadRequest)
return
}
result, err := handler(r.Context(), s, loginMgr, cfg, providerName, &pc, strings.TrimSpace(body.CredentialFile))
if err != nil {
status := http.StatusBadRequest
if action == "delete" && strings.Contains(strings.ToLower(err.Error()), "config") {
status = http.StatusInternalServerError
}
http.Error(w, err.Error(), status)
return
}
writeJSON(w, result)
default:
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
}
}
type providerOAuthAccountActionHandler func(context.Context, *Server, *providers.OAuthLoginManager, *cfgpkg.Config, string, *cfgpkg.ProviderConfig, string) (map[string]interface{}, error)
var providerOAuthAccountActionHandlers = map[string]providerOAuthAccountActionHandler{
"refresh": func(ctx context.Context, _ *Server, loginMgr *providers.OAuthLoginManager, _ *cfgpkg.Config, _ string, _ *cfgpkg.ProviderConfig, credentialFile string) (map[string]interface{}, error) {
account, err := loginMgr.RefreshAccount(ctx, credentialFile)
if err != nil {
return nil, err
}
return map[string]interface{}{"ok": true, "account": account}, nil
},
"delete": func(_ context.Context, srv *Server, loginMgr *providers.OAuthLoginManager, cfg *cfgpkg.Config, providerName string, pc *cfgpkg.ProviderConfig, credentialFile string) (map[string]interface{}, error) {
if err := loginMgr.DeleteAccount(credentialFile); err != nil {
return nil, err
}
pc.OAuth.CredentialFiles = removeStringItem(pc.OAuth.CredentialFiles, credentialFile)
if strings.TrimSpace(pc.OAuth.CredentialFile) == strings.TrimSpace(credentialFile) {
pc.OAuth.CredentialFile = ""
if len(pc.OAuth.CredentialFiles) > 0 {
pc.OAuth.CredentialFile = pc.OAuth.CredentialFiles[0]
}
}
if err := srv.saveProviderConfig(cfg, providerName, *pc); err != nil {
return nil, err
}
return map[string]interface{}{"ok": true, "deleted": true}, nil
},
"clear_cooldown": func(_ context.Context, _ *Server, loginMgr *providers.OAuthLoginManager, _ *cfgpkg.Config, _ string, _ *cfgpkg.ProviderConfig, credentialFile string) (map[string]interface{}, error) {
if err := loginMgr.ClearCooldown(credentialFile); err != nil {
return nil, err
}
return map[string]interface{}{"ok": true, "cleared": true}, nil
},
}
func (s *Server) loadProviderConfig(name string) (*cfgpkg.Config, cfgpkg.ProviderConfig, error) {
if strings.TrimSpace(s.configPath) == "" {
return nil, cfgpkg.ProviderConfig{}, fmt.Errorf("config path not set")
}
cfg, err := cfgpkg.LoadConfig(s.configPath)
if err != nil {
return nil, cfgpkg.ProviderConfig{}, err
}
providerName := strings.TrimSpace(name)
if providerName == "" {
providerName = cfgpkg.PrimaryProviderName(cfg)
}
pc, ok := cfgpkg.ProviderConfigByName(cfg, providerName)
if !ok {
return nil, cfgpkg.ProviderConfig{}, fmt.Errorf("provider %q not found", providerName)
}
return cfg, pc, nil
}
func (s *Server) loadRuntimeProviderName(name string) (*cfgpkg.Config, string, error) {
if strings.TrimSpace(s.configPath) == "" {
return nil, "", fmt.Errorf("config path not set")
}
cfg, err := cfgpkg.LoadConfig(s.configPath)
if err != nil {
return nil, "", err
}
providerName := strings.TrimSpace(name)
if providerName == "" {
providerName = cfgpkg.PrimaryProviderName(cfg)
}
if !cfgpkg.ProviderExists(cfg, providerName) {
return nil, "", fmt.Errorf("provider %q not found", providerName)
}
return cfg, providerName, nil
}
func (s *Server) resolveProviderConfig(name string, inline cfgpkg.ProviderConfig) (*cfgpkg.Config, cfgpkg.ProviderConfig, error) {
if hasInlineProviderConfig(inline) {
cfg, err := cfgpkg.LoadConfig(s.configPath)
if err != nil {
return nil, cfgpkg.ProviderConfig{}, err
}
return cfg, inline, nil
}
return s.loadProviderConfig(name)
}
func hasInlineProviderConfig(pc cfgpkg.ProviderConfig) bool {
return strings.TrimSpace(pc.APIBase) != "" ||
strings.TrimSpace(pc.APIKey) != "" ||
len(pc.Models) > 0 ||
strings.TrimSpace(pc.Auth) != "" ||
strings.TrimSpace(pc.OAuth.Provider) != ""
}
func (s *Server) saveProviderConfig(cfg *cfgpkg.Config, name string, pc cfgpkg.ProviderConfig) error {
if cfg == nil {
return fmt.Errorf("config is nil")
}
providerName := strings.TrimSpace(name)
if cfg.Models.Providers == nil {
cfg.Models.Providers = map[string]cfgpkg.ProviderConfig{}
}
cfg.Models.Providers[providerName] = pc
if err := cfgpkg.SaveConfig(s.configPath, cfg); err != nil {
return err
}
if s.onConfigAfter != nil {
if err := s.onConfigAfter(); err != nil {
return err
}
} else {
if err := requestSelfReloadSignal(); err != nil {
return err
}
}
return nil
}
func appendUniqueStrings(values []string, item string) []string {
item = strings.TrimSpace(item)
if item == "" {
return values
}
for _, value := range values {
if strings.TrimSpace(value) == item {
return values
}
}
return append(values, item)
}
func removeStringItem(values []string, item string) []string {
item = strings.TrimSpace(item)
if item == "" {
return values
}
out := make([]string, 0, len(values))
for _, value := range values {
if strings.TrimSpace(value) == item {
continue
}
out = append(out, value)
}
return out
}
func atoiDefault(raw string, fallback int) int {
if value, err := strconv.Atoi(strings.TrimSpace(raw)); err == nil {
return value
}
return fallback
}