diff --git a/cluster/parsecfg.go b/cluster/parsecfg.go index 72e104b..564f78a 100644 --- a/cluster/parsecfg.go +++ b/cluster/parsecfg.go @@ -7,6 +7,7 @@ import ( "github.com/duanhf2012/origin/v2/rpc" jsoniter "github.com/json-iterator/go" "gopkg.in/yaml.v3" + "io/fs" "os" "path/filepath" "strings" @@ -70,12 +71,8 @@ type NodeInfoList struct { NodeList []NodeInfo } -func validConfigFile(f os.DirEntry) bool { - if f.IsDir() == true || (filepath.Ext(f.Name()) != ".json" && filepath.Ext(f.Name()) != ".yml" && filepath.Ext(f.Name()) != ".yaml") { - return false - } - - return true +func validConfigFile(f string) bool { + return strings.HasSuffix(f, ".json")|| strings.HasSuffix(f, ".yml") || strings.HasSuffix(f, ".yaml") } func yamlToJson(data []byte, v interface{}) ([]byte, error) { @@ -277,32 +274,33 @@ func (cls *Cluster) readLocalClusterConfig(nodeId string) (DiscoveryInfo, []Node var discoveryInfo DiscoveryInfo var rpcMode RpcMode - clusterCfgPath := strings.TrimRight(configDir, "/") + "/cluster" - fileInfoList, err := os.ReadDir(clusterCfgPath) - if err != nil { - return discoveryInfo, nil, rpcMode, fmt.Errorf("read dir %s is fail :%+v", clusterCfgPath, err) - } - //读取任何文件,只读符合格式的配置,目录下的文件可以自定义分文件 - for _, f := range fileInfoList { - if !validConfigFile(f) { - continue + err := filepath.Walk(configDir, func(path string, info fs.FileInfo, err error)error { + if info.IsDir() { + return nil } - filePath := strings.TrimRight(strings.TrimRight(clusterCfgPath, "/"), "\\") + "/" + f.Name() - fileNodeInfoList, rErr := cls.ReadClusterConfig(filePath) + if err != nil { + return err + } + + if !validConfigFile(info.Name()) { + return nil + } + + fileNodeInfoList, rErr := cls.ReadClusterConfig(path) if rErr != nil { - return discoveryInfo, nil, rpcMode, fmt.Errorf("read file path %s is error:%+v", filePath, rErr) + return fmt.Errorf("read file path %s is error:%+v", path, rErr) } err = cls.SetRpcMode(&fileNodeInfoList.RpcMode, &rpcMode) if err != nil { - return discoveryInfo, nil, rpcMode, err + return err } err = discoveryInfo.setDiscovery(&fileNodeInfoList.Discovery) if err != nil { - return discoveryInfo, nil, rpcMode, err + return err } for _, nodeInfo := range fileNodeInfoList.NodeList { @@ -310,6 +308,12 @@ func (cls *Cluster) readLocalClusterConfig(nodeId string) (DiscoveryInfo, []Node nodeInfoList = append(nodeInfoList, nodeInfo) } } + + return nil + }) + + if err != nil { + return discoveryInfo, nil, rpcMode, err } if nodeId != rpc.NodeIdNull && (len(nodeInfoList) != 1) { @@ -331,32 +335,32 @@ func (cls *Cluster) readLocalClusterConfig(nodeId string) (DiscoveryInfo, []Node } func (cls *Cluster) readLocalService(localNodeId string) error { - clusterCfgPath := strings.TrimRight(configDir, "/") + "/cluster" - fileInfoList, err := os.ReadDir(clusterCfgPath) - if err != nil { - return fmt.Errorf("read dir %s is fail :%+v", clusterCfgPath, err) - } - var globalCfg interface{} publicService := map[string]interface{}{} nodeService := map[string]interface{}{} //读取任何文件,只读符合格式的配置,目录下的文件可以自定义分文件 - for _, f := range fileInfoList { - if !validConfigFile(f) { - continue + err := filepath.Walk(configDir, func(path string, info fs.FileInfo, err error)error{ + if info.IsDir() { + return nil } - filePath := strings.TrimRight(strings.TrimRight(clusterCfgPath, "/"), "\\") + "/" + f.Name() - currGlobalCfg, serviceConfig, mapNodeService, err := cls.readServiceConfig(filePath) if err != nil { - continue + return err + } + + if !validConfigFile(info.Name()) { + return nil + } + currGlobalCfg, serviceConfig, mapNodeService, err := cls.readServiceConfig(path) + if err != nil { + return err } if currGlobalCfg != nil { //不允许重复的配置global配置 if globalCfg != nil { - return fmt.Errorf("[Global] does not allow repeated configuration in %s", f.Name()) + return fmt.Errorf("[Global] does not allow repeated configuration in %s", info.Name()) } globalCfg = currGlobalCfg } @@ -372,7 +376,7 @@ func (cls *Cluster) readLocalService(localNodeId string) error { pubCfg, ok := serviceConfig[s] if ok == true { if _, publicOk := publicService[s]; publicOk == true { - return fmt.Errorf("public service [%s] does not allow repeated configuration in %s", s, f.Name()) + return fmt.Errorf("public service [%s] does not allow repeated configuration in %s", s, info.Name()) } publicService[s] = pubCfg } @@ -388,12 +392,17 @@ func (cls *Cluster) readLocalService(localNodeId string) error { } if _, nodeOK := nodeService[s]; nodeOK == true { - return fmt.Errorf("NodeService NodeId[%s] Service[%s] does not allow repeated configuration in %s", cls.localNodeInfo.NodeId, s, f.Name()) + return fmt.Errorf("NodeService NodeId[%s] Service[%s] does not allow repeated configuration in %s", cls.localNodeInfo.NodeId, s, info.Name()) } nodeService[s] = nodeCfg break } } + return nil + }) + + if err != nil { + return err } //组合所有的配置