强曰为道
与天地相似,故不违。知周乎万物,而道济天下,故不过。旁行而不流,乐天知命,故不忧。
文档目录

Go 语言完全指南 / 30 - 实战项目:REST API、CLI 工具、微服务、爬虫

30 - 实战项目

30.1 REST API 服务

完整的 Todo API

// cmd/todo/main.go
package main

import (
    "context"
    "encoding/json"
    "fmt"
    "log"
    "net/http"
    "os"
    "os/signal"
    "strconv"
    "sync"
    "syscall"
    "time"
)

// ===== 模型 =====
// Todo is a single to-do item as held by TodoStore and serialized to JSON
// in API responses.
type Todo struct {
    ID        int       `json:"id"`         // assigned by TodoStore.Create from its nextID counter
    Title     string    `json:"title"`      // user-supplied text
    Completed bool      `json:"completed"`  // false on creation; changed through TodoStore.Update
    CreatedAt time.Time `json:"created_at"` // set to time.Now() when the item is created
}

// CreateTodoRequest is the JSON body accepted by the create endpoint.
//
// NOTE(review): no code visible in this file reads the `validate` tag; the
// handler checks for an empty title by hand.
type CreateTodoRequest struct {
    Title string `json:"title" validate:"required"`
}

// UpdateTodoRequest is the JSON body accepted by the update endpoint.
// Both fields are pointers so that a field missing from the body (nil) can be
// told apart from one explicitly set to its zero value; TodoStore.Update only
// applies the non-nil ones.
type UpdateTodoRequest struct {
    Title     *string `json:"title,omitempty"`
    Completed *bool   `json:"completed,omitempty"`
}

// ===== 存储层 =====
// TodoStore is an in-memory collection of todos keyed by ID.
// Its methods take mu before touching todos or nextID.
type TodoStore struct {
    mu     sync.RWMutex  // guards todos and nextID
    todos  map[int]*Todo // items by ID
    nextID int           // ID that the next Create call will assign
}

// NewTodoStore returns an empty store whose first assigned ID will be 1.
func NewTodoStore() *TodoStore {
    store := new(TodoStore)
    store.todos = make(map[int]*Todo)
    store.nextID = 1
    return store
}

// Create stores a new, not-yet-completed todo with the given title, assigns it
// the next free ID and stamps it with the current time.
//
// The returned *Todo is a copy of the stored item: callers use it after the
// lock has been released (for example while encoding a response), so handing
// out the stored pointer would let that read race with a concurrent Update.
func (s *TodoStore) Create(title string) *Todo {
    s.mu.Lock()
    defer s.mu.Unlock()

    todo := &Todo{
        ID:        s.nextID,
        Title:     title,
        Completed: false,
        CreatedAt: time.Now(),
    }
    s.todos[todo.ID] = todo
    s.nextID++

    snapshot := *todo
    return &snapshot
}

// GetAll returns every stored todo, ordered by ascending ID.
//
// The result is never nil (an empty store yields an empty slice), and each
// element is a copy of the stored item so callers can read it after the lock
// is released without racing with Update.
func (s *TodoStore) GetAll() []*Todo {
    s.mu.RLock()
    defer s.mu.RUnlock()

    result := make([]*Todo, 0, len(s.todos))
    // Create hands out IDs sequentially and never reuses them, so every live
    // ID lies below nextID. Walking that range gives a stable order, unlike
    // ranging over the map, whose iteration order is randomized. Stop as soon
    // as every stored item has been collected.
    for id := 0; id < s.nextID && len(result) < len(s.todos); id++ {
        t, ok := s.todos[id]
        if !ok {
            continue // this ID was deleted
        }
        snapshot := *t
        result = append(result, &snapshot)
    }
    return result
}

// GetByID looks up the todo with the given ID.
// It returns (nil, false) when no such item exists. On success the returned
// *Todo is a copy of the stored item, so reading it after the lock is released
// cannot race with a concurrent Update.
func (s *TodoStore) GetByID(id int) (*Todo, bool) {
    s.mu.RLock()
    defer s.mu.RUnlock()

    t, ok := s.todos[id]
    if !ok {
        return nil, false
    }
    snapshot := *t
    return &snapshot, true
}

// Update applies a partial change to the todo with the given ID.
// A nil title or completed leaves the corresponding field untouched.
// It returns (nil, false) when no such item exists. On success the returned
// *Todo is a copy of the updated item, so callers can read it after the lock
// is released without racing with a later Update.
func (s *TodoStore) Update(id int, title *string, completed *bool) (*Todo, bool) {
    s.mu.Lock()
    defer s.mu.Unlock()

    t, ok := s.todos[id]
    if !ok {
        return nil, false
    }
    if title != nil {
        t.Title = *title
    }
    if completed != nil {
        t.Completed = *completed
    }

    snapshot := *t
    return &snapshot, true
}

// Delete removes the todo with the given ID and reports whether it existed.
func (s *TodoStore) Delete(id int) bool {
    s.mu.Lock()
    defer s.mu.Unlock()

    _, found := s.todos[id]
    if found {
        delete(s.todos, id)
    }
    return found
}

// ===== 处理器 =====
// TodoHandler groups the HTTP handlers of the Todo API; each one delegates
// to the underlying store.
type TodoHandler struct {
    store *TodoStore
}

// NewTodoHandler builds a handler set backed by store.
func NewTodoHandler(store *TodoStore) *TodoHandler {
    h := new(TodoHandler)
    h.store = store
    return h
}

// List responds with 200 and the JSON array of all todos in the store.
func (h *TodoHandler) List(w http.ResponseWriter, r *http.Request) {
    respondJSON(w, http.StatusOK, h.store.GetAll())
}

// Create decodes a CreateTodoRequest from the request body and stores a new
// todo.
//
// Responses:
//   - 201 with the created todo on success
//   - 400 when the body is not valid JSON, exceeds the size cap, or has an
//     empty title
func (h *TodoHandler) Create(w http.ResponseWriter, r *http.Request) {
    // Cap the body at 1 MiB so a client cannot make the decoder buffer an
    // arbitrarily large payload; reads past the cap fail and surface as a
    // decode error below.
    r.Body = http.MaxBytesReader(w, r.Body, 1<<20)

    var req CreateTodoRequest
    if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
        respondError(w, http.StatusBadRequest, "Invalid request body")
        return
    }
    if req.Title == "" {
        respondError(w, http.StatusBadRequest, "Title is required")
        return
    }
    todo := h.store.Create(req.Title)
    respondJSON(w, http.StatusCreated, todo)
}

// Get responds with the todo named by the {id} path segment.
// It answers 400 when the segment is not an integer and 404 when the store
// has no such item.
func (h *TodoHandler) Get(w http.ResponseWriter, r *http.Request) {
    rawID := r.PathValue("id")
    id, parseErr := strconv.Atoi(rawID)
    if parseErr != nil {
        respondError(w, http.StatusBadRequest, "Invalid ID")
        return
    }

    if todo, found := h.store.GetByID(id); found {
        respondJSON(w, http.StatusOK, todo)
        return
    }
    respondError(w, http.StatusNotFound, "Todo not found")
}

// Update applies the partial change described by an UpdateTodoRequest to the
// todo named by the {id} path segment.
//
// Responses:
//   - 200 with the updated todo on success
//   - 400 when the ID is not an integer, the body is not valid JSON or exceeds
//     the size cap, or the title is present but empty
//   - 404 when the store has no such item
func (h *TodoHandler) Update(w http.ResponseWriter, r *http.Request) {
    id, err := strconv.Atoi(r.PathValue("id"))
    if err != nil {
        respondError(w, http.StatusBadRequest, "Invalid ID")
        return
    }

    // Cap the body at 1 MiB so a client cannot make the decoder buffer an
    // arbitrarily large payload.
    r.Body = http.MaxBytesReader(w, r.Body, 1<<20)

    var req UpdateTodoRequest
    if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
        respondError(w, http.StatusBadRequest, "Invalid request body")
        return
    }
    // Creation refuses an empty title, so an update must not be able to
    // blank it either. An omitted title (nil) is still fine.
    if req.Title != nil && *req.Title == "" {
        respondError(w, http.StatusBadRequest, "Title is required")
        return
    }

    todo, ok := h.store.Update(id, req.Title, req.Completed)
    if !ok {
        respondError(w, http.StatusNotFound, "Todo not found")
        return
    }
    respondJSON(w, http.StatusOK, todo)
}

// Delete removes the todo named by the {id} path segment.
// It answers 204 with no body on success, 400 when the segment is not an
// integer, and 404 when the store has no such item.
func (h *TodoHandler) Delete(w http.ResponseWriter, r *http.Request) {
    id, parseErr := strconv.Atoi(r.PathValue("id"))
    if parseErr != nil {
        respondError(w, http.StatusBadRequest, "Invalid ID")
        return
    }

    if removed := h.store.Delete(id); removed {
        w.WriteHeader(http.StatusNoContent)
        return
    }
    respondError(w, http.StatusNotFound, "Todo not found")
}

// ===== 中间件 =====
// loggingMiddleware wraps next so that, once next has returned, one log line
// with the request method, URL path and elapsed time is written.
func loggingMiddleware(next http.Handler) http.Handler {
    logged := func(w http.ResponseWriter, r *http.Request) {
        began := time.Now()
        next.ServeHTTP(w, r)
        elapsed := time.Since(began)
        log.Printf("%s %s %v", r.Method, r.URL.Path, elapsed)
    }
    return http.HandlerFunc(logged)
}

// corsMiddleware sets permissive CORS headers on every response. OPTIONS
// requests are answered directly with 200 and never reach next; all other
// requests are passed through.
func corsMiddleware(next http.Handler) http.Handler {
    withCORS := func(w http.ResponseWriter, r *http.Request) {
        hdr := w.Header()
        hdr.Set("Access-Control-Allow-Origin", "*")
        hdr.Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS")
        hdr.Set("Access-Control-Allow-Headers", "Content-Type")

        if r.Method != http.MethodOptions {
            next.ServeHTTP(w, r)
            return
        }
        w.WriteHeader(http.StatusOK)
    }
    return http.HandlerFunc(withCORS)
}

// ===== 工具函数 =====
// respondJSON writes data as a JSON response with the given status code and a
// Content-Type of application/json.
//
// The status line is sent before encoding starts, so an encoding or write
// failure can no longer be turned into an error response; it is logged
// instead of being silently dropped.
func respondJSON(w http.ResponseWriter, status int, data any) {
    w.Header().Set("Content-Type", "application/json")
    w.WriteHeader(status)
    if err := json.NewEncoder(w).Encode(data); err != nil {
        log.Printf("respondJSON: encoding response: %v", err)
    }
}

// respondError sends message to the client as a JSON object of the form
// {"error": message} with the given status code.
func respondError(w http.ResponseWriter, status int, message string) {
    payload := map[string]string{"error": message}
    respondJSON(w, status, payload)
}

// ===== 主函数 =====
// main wires the store, handlers and middleware together, serves the API on
// :8080 and shuts the server down gracefully on SIGINT or SIGTERM.
func main() {
    store := NewTodoStore()
    handler := NewTodoHandler(store)

    mux := http.NewServeMux()

    // Health check.
    mux.HandleFunc("GET /health", func(w http.ResponseWriter, r *http.Request) {
        respondJSON(w, http.StatusOK, map[string]string{"status": "ok"})
    })

    // Todo API (method-and-wildcard patterns need Go 1.22+).
    mux.HandleFunc("GET /api/todos", handler.List)
    mux.HandleFunc("POST /api/todos", handler.Create)
    mux.HandleFunc("GET /api/todos/{id}", handler.Get)
    mux.HandleFunc("PUT /api/todos/{id}", handler.Update)
    mux.HandleFunc("DELETE /api/todos/{id}", handler.Delete)

    // Logging is outermost so it also times the CORS short-circuit.
    app := loggingMiddleware(corsMiddleware(mux))

    server := &http.Server{
        Addr:         ":8080",
        Handler:      app,
        ReadTimeout:  10 * time.Second,
        WriteTimeout: 10 * time.Second,
    }

    // Register for signals before the server starts so none is missed.
    quit := make(chan os.Signal, 1)
    signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)

    go func() {
        log.Println("Server starting on :8080")
        // ListenAndServe returns http.ErrServerClosed after Shutdown; anything
        // else (e.g. the port is already in use) is fatal.
        if err := server.ListenAndServe(); err != http.ErrServerClosed {
            log.Fatal(err)
        }
    }()

    <-quit
    log.Println("Shutting down...")

    // Give in-flight requests up to 10 seconds to finish.
    ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
    defer cancel()
    if err := server.Shutdown(ctx); err != nil {
        log.Printf("Graceful shutdown failed: %v", err)
        // Shutdown gave up, so drop whatever connections are still open.
        if err := server.Close(); err != nil {
            log.Printf("Forced close failed: %v", err)
        }
    }
    log.Println("Server stopped")
}
# 测试 API
curl -X POST http://localhost:8080/api/todos -d '{"title":"学习 Go"}'
curl http://localhost:8080/api/todos
curl -X PUT http://localhost:8080/api/todos/1 -d '{"completed":true}'
curl -X DELETE http://localhost:8080/api/todos/1

30.2 CLI 工具

文件搜索工具

// cmd/findfile/main.go
package main

import (
    "fmt"
    "os"
    "path/filepath"
    "regexp"
    "strings"
    "sync"
    "time"

    "github.com/spf13/cobra"
)

// FileMatch describes one file that passed every search filter.
type FileMatch struct {
    Path    string    // path as produced by the directory walk
    Size    int64     // size in bytes, from os.Stat
    ModTime time.Time // last modification time, from os.Stat
}

// main defines the findfile command line (one directory argument plus the
// filter flags) and runs the search. The process exits with status 1 when
// argument parsing or the search itself fails.
func main() {
    var (
        pattern string
        ext     string
        maxSize int64
        workers int
    )

    rootCmd := &cobra.Command{
        Use:   "findfile [directory]",
        Short: "快速文件搜索工具",
        Args:  cobra.ExactArgs(1),
        RunE: func(cmd *cobra.Command, args []string) error {
            dir := args[0]
            return searchFiles(dir, pattern, ext, maxSize, workers)
        },
    }

    rootCmd.Flags().StringVarP(&pattern, "pattern", "p", "", "文件名正则模式")
    rootCmd.Flags().StringVarP(&ext, "ext", "e", "", "文件扩展名(如 .go)")
    rootCmd.Flags().Int64VarP(&maxSize, "max-size", "s", 0, "最大文件大小(字节)")
    rootCmd.Flags().IntVarP(&workers, "workers", "w", 4, "并发 worker 数")

    // Cobra reports the error itself by default; all that is left to do is
    // signal the failure to the calling shell.
    if err := rootCmd.Execute(); err != nil {
        os.Exit(1)
    }
}

// searchFiles walks dir and prints every regular entry that passes all active
// filters, followed by a total count.
//
// Parameters:
//   - dir:     root of the walk; must exist
//   - pattern: regular expression matched against the base name ("" disables)
//   - ext:     required extension, with or without the leading dot ("" disables)
//   - maxSize: largest accepted size in bytes (values <= 0 disable the filter)
//   - workers: number of concurrent filter goroutines (values < 1 are treated as 1)
//
// It returns an error when pattern does not compile or dir cannot be accessed.
// Entries that cannot be read during the walk are skipped on purpose. Output
// order is not deterministic because matches arrive from concurrent workers.
func searchFiles(dir, pattern, ext string, maxSize int64, workers int) error {
    var re *regexp.Regexp
    if pattern != "" {
        var err error
        re, err = regexp.Compile(pattern)
        if err != nil {
            return fmt.Errorf("invalid pattern: %w", err)
        }
    }

    // Fail fast on a missing or unreadable root; the walk below deliberately
    // swallows per-entry errors and would otherwise report "0 files".
    if _, err := os.Stat(dir); err != nil {
        return fmt.Errorf("cannot access %s: %w", dir, err)
    }

    // filepath.Ext returns the extension including the dot, so accept "go"
    // as a synonym for ".go".
    if ext != "" && !strings.HasPrefix(ext, ".") {
        ext = "." + ext
    }

    // With no worker nothing would drain the files channel and the search
    // would silently find nothing.
    if workers < 1 {
        workers = 1
    }

    files := make(chan string, 1000)
    matches := make(chan FileMatch, 100)
    var wg sync.WaitGroup

    // Producer: walk the tree and feed every non-directory path to the workers.
    go func() {
        defer close(files)
        filepath.Walk(dir, func(path string, info os.FileInfo, err error) error {
            if err != nil {
                return nil // best effort: skip entries that cannot be read
            }
            if !info.IsDir() {
                files <- path
            }
            return nil
        })
    }()

    // Workers: apply the filters and forward the survivors.
    for i := 0; i < workers; i++ {
        wg.Add(1)
        go func() {
            defer wg.Done()
            for path := range files {
                // Name-based filters first: they need no system call.
                if ext != "" && filepath.Ext(path) != ext {
                    continue
                }
                if re != nil && !re.MatchString(filepath.Base(path)) {
                    continue
                }
                info, err := os.Stat(path)
                if err != nil {
                    continue
                }
                if maxSize > 0 && info.Size() > maxSize {
                    continue
                }
                matches <- FileMatch{
                    Path:    path,
                    Size:    info.Size(),
                    ModTime: info.ModTime(),
                }
            }
        }()
    }

    // Close matches once every worker is done so the printing loop ends.
    go func() {
        wg.Wait()
        close(matches)
    }()

    count := 0
    for m := range matches {
        count++
        size := formatSize(m.Size)
        fmt.Printf("%s  %8s  %s\n", m.ModTime.Format("2006-01-02 15:04"), size, m.Path)
    }
    fmt.Printf("\n找到 %d 个文件\n", count)
    return nil
}

// formatSize renders a byte count for display. Values below 1024 are printed
// as "<n> B"; larger values are divided by the largest power of 1024 that fits
// and printed with one decimal and a single-letter prefix (K, M, G, T, P, E).
func formatSize(bytes int64) string {
    const (
        unit     = 1024
        prefixes = "KMGTPE"
    )
    if bytes < unit {
        return fmt.Sprintf("%d B", bytes)
    }

    scale := int64(unit)
    idx := 0
    remaining := bytes / unit
    for remaining >= unit {
        scale *= unit
        idx++
        remaining /= unit
    }

    scaled := float64(bytes) / float64(scale)
    return fmt.Sprintf("%.1f %c", scaled, prefixes[idx])
}

30.3 微服务示例

用户服务

// 简化的 gRPC 用户服务
// user.proto
syntax = "proto3";
package user;
option go_package = "./pb";

// UserService exposes lookup, creation and paged listing of users.
service UserService {
    // GetUser returns the user with the requested id.
    rpc GetUser(GetUserRequest) returns (User);
    // CreateUser stores a new user and returns it.
    rpc CreateUser(CreateUserRequest) returns (User);
    // ListUsers returns one page of users together with a total count.
    rpc ListUsers(ListUsersRequest) returns (ListUsersResponse);
}

// User is a single user record.
message User {
    int32 id = 1;
    string name = 2;
    string email = 3;
}

// Request and response messages for the RPCs above.
message GetUserRequest { int32 id = 1; }
message CreateUserRequest { string name = 1; string email = 2; }
message ListUsersRequest { int32 page = 1; int32 page_size = 2; }
message ListUsersResponse {
    repeated User users = 1;
    int32 total = 2;
}
// server/main.go
package main

import (
    "context"
    "log"
    "net"
    "sync"

    "google.golang.org/grpc"
    "google.golang.org/grpc/codes"
    "google.golang.org/grpc/status"

    pb "myproject/pb"
)

// userServer is an in-memory implementation of the UserService gRPC service.
// RPCs that are not overridden here fall back to the embedded
// UnimplementedUserServiceServer.
//
// gRPC serves requests concurrently, so every access to users and nextID
// goes through mu.
type userServer struct {
    pb.UnimplementedUserServiceServer

    mu     sync.RWMutex
    users  map[int32]*pb.User // users by ID
    nextID int32              // ID that the next CreateUser call will assign
}

// GetUser returns the user with the requested ID, or a NotFound status error
// when there is none.
func (s *userServer) GetUser(ctx context.Context, req *pb.GetUserRequest) (*pb.User, error) {
    s.mu.RLock()
    user, ok := s.users[req.Id]
    s.mu.RUnlock()

    if !ok {
        return nil, status.Errorf(codes.NotFound, "user %d not found", req.Id)
    }
    return user, nil
}

// CreateUser stores a new user built from the request's name and email under
// the next free ID and returns it.
func (s *userServer) CreateUser(ctx context.Context, req *pb.CreateUserRequest) (*pb.User, error) {
    s.mu.Lock()
    defer s.mu.Unlock()

    user := &pb.User{
        Id:    s.nextID,
        Name:  req.Name,
        Email: req.Email,
    }
    s.users[user.Id] = user
    s.nextID++
    return user, nil
}

// main listens on TCP port 50051 and serves the UserService until the server
// stops; listen and serve failures terminate the process.
func main() {
    lis, err := net.Listen("tcp", ":50051")
    if err != nil {
        log.Fatal(err)
    }

    srv := grpc.NewServer()
    pb.RegisterUserServiceServer(srv, &userServer{
        users: make(map[int32]*pb.User),
        // Start at 1: in proto3 an id of 0 is the default value and is
        // indistinguishable from "not set" in a GetUserRequest.
        nextID: 1,
    })

    log.Println("gRPC server starting on :50051")
    if err := srv.Serve(lis); err != nil {
        log.Fatal(err)
    }
}

30.4 网络爬虫

package main

import (
    "fmt"
    "io"
    "net/http"
    "regexp"
    "sync"
    "time"
)

// Crawler fetches pages concurrently and follows the links found in them.
type Crawler struct {
    client   *http.Client    // HTTP client used for every fetch
    visited  map[string]bool // URLs already claimed by a Crawl call; guarded by mu
    mu       sync.Mutex      // guards visited
    wg       sync.WaitGroup  // counts crawl goroutines that have not finished yet
    sem      chan struct{} // concurrency limit: one buffered slot per allowed crawl
}

// NewCrawler returns a crawler whose HTTP client times out after 10 seconds
// and that runs at most maxConcurrency crawls at a time.
//
// Values below 1 are treated as 1: a zero-capacity semaphore could never be
// acquired and a negative capacity would make the channel allocation panic.
func NewCrawler(maxConcurrency int) *Crawler {
    if maxConcurrency < 1 {
        maxConcurrency = 1
    }
    return &Crawler{
        client: &http.Client{
            Timeout: 10 * time.Second,
        },
        visited: make(map[string]bool),
        sem:     make(chan struct{}, maxConcurrency),
    }
}

// Crawl fetches url in a new goroutine and recursively crawls the absolute
// links found in it, down to depth levels. It returns immediately; wait on
// c.wg to know when the whole crawl has finished.
//
// A URL is marked visited before it is fetched, so each URL is claimed by at
// most one Crawl call even when several pages link to it. Fetch failures are
// printed and that branch of the crawl is abandoned.
func (c *Crawler) Crawl(url string, depth int) {
    if depth <= 0 {
        return
    }

    c.mu.Lock()
    if c.visited[url] {
        c.mu.Unlock()
        return
    }
    c.visited[url] = true
    c.mu.Unlock()

    // Add must happen before the goroutine starts. A recursive call is safe
    // because it runs inside a goroutine that has not called Done yet, so the
    // counter cannot reach zero in between.
    c.wg.Add(1)
    go func() {
        defer c.wg.Done()

        body, err := c.fetch(url, depth)
        if err != nil {
            fmt.Printf("Error fetching %s: %v\n", url, err)
            return
        }

        // The semaphore slot was released inside fetch. Holding it while
        // recursing would deadlock as soon as every slot belongs to a
        // goroutine that is itself waiting for a slot for one of its links.
        for _, link := range extractLinks(body, url) {
            c.Crawl(link, depth-1)
        }
    }()
}

// fetch downloads url while holding one semaphore slot and returns the
// response body as a string.
func (c *Crawler) fetch(url string, depth int) (string, error) {
    // Upper bound on how much of a page is read into memory.
    const maxBodyBytes = 10 << 20

    c.sem <- struct{}{}        // acquire a slot
    defer func() { <-c.sem }() // release it

    fmt.Printf("Crawling: %s (depth=%d)\n", url, depth)

    resp, err := c.client.Get(url)
    if err != nil {
        return "", err
    }
    defer resp.Body.Close()

    if resp.StatusCode != http.StatusOK {
        return "", fmt.Errorf("unexpected status %s", resp.Status)
    }

    body, err := io.ReadAll(io.LimitReader(resp.Body, maxBodyBytes))
    if err != nil {
        return "", err
    }
    return string(body), nil
}

// linkPattern matches href attributes that hold an absolute http(s) URL.
// It is compiled once at package level rather than on every call.
var linkPattern = regexp.MustCompile(`href="(https?://[^"]+)"`)

// extractLinks returns every absolute http:// or https:// URL found in a
// double-quoted href attribute of html, in document order and including
// duplicates. Relative links are ignored. The result is nil when nothing
// matches.
//
// baseURL is currently unused; it is kept so that existing callers keep
// compiling.
func extractLinks(html, baseURL string) []string {
    matches := linkPattern.FindAllStringSubmatch(html, -1)
    if len(matches) == 0 {
        return nil
    }
    links := make([]string, 0, len(matches))
    for _, match := range matches {
        // match[0] is the whole attribute, match[1] the captured URL.
        links = append(links, match[1])
    }
    return links
}

// main crawls https://example.com two levels deep and reports how many URLs
// were visited.
func main() {
    const (
        maxConcurrent = 5 // at most 5 concurrent crawls
        maxDepth      = 2
    )

    c := NewCrawler(maxConcurrent)
    c.Crawl("https://example.com", maxDepth)
    c.wg.Wait()

    visitedCount := len(c.visited)
    fmt.Printf("共访问 %d 个页面\n", visitedCount)
}

30.5 并发下载器

package main

import (
    "context"
    "fmt"
    "io"
    "net/http"
    "os"
    "path/filepath"
    "sync"
    "time"

    "golang.org/x/sync/errgroup"
)

// DownloadTask pairs a source URL with the local path it is saved to.
type DownloadTask struct {
    URL      string // address to fetch
    Filename string // destination path on disk
}

// downloadFile fetches url with a GET request bound to ctx and writes the
// response body to dest, creating missing parent directories.
//
// It returns an error when the request fails, the status is not 200, the
// destination cannot be created, or the body cannot be written and flushed
// completely. On failure after dest was created, the partial file is removed
// so that a truncated download is never mistaken for a complete one.
func downloadFile(ctx context.Context, url, dest string) (err error) {
    req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
    if err != nil {
        return err
    }

    resp, err := http.DefaultClient.Do(req)
    if err != nil {
        return err
    }
    defer resp.Body.Close()

    if resp.StatusCode != http.StatusOK {
        return fmt.Errorf("status %d", resp.StatusCode)
    }

    if err := os.MkdirAll(filepath.Dir(dest), 0755); err != nil {
        return err
    }
    f, err := os.Create(dest)
    if err != nil {
        return err
    }
    defer func() {
        // Close can report a deferred write error; do not lose it, but keep
        // the earlier error when there already is one.
        if cerr := f.Close(); err == nil {
            err = cerr
        }
        if err != nil {
            _ = os.Remove(dest) // best-effort cleanup of the partial file
        }
    }()

    _, err = io.Copy(f, resp.Body)
    return err
}

// downloadAll runs every task through downloadFile, with at most
// maxConcurrency downloads in flight. It waits for all started downloads and
// returns the first error; that error also cancels the context shared by the
// remaining downloads.
func downloadAll(tasks []DownloadTask, maxConcurrency int) error {
    group, ctx := errgroup.WithContext(context.Background())
    group.SetLimit(maxConcurrency)

    for i := range tasks {
        current := tasks[i] // per-iteration copy captured by the closure
        group.Go(func() error {
            fmt.Printf("Downloading: %s\n", current.Filename)
            if err := downloadFile(ctx, current.URL, current.Filename); err != nil {
                return fmt.Errorf("download %s: %w", current.URL, err)
            }
            fmt.Printf("Done: %s\n", current.Filename)
            return nil
        })
    }

    return group.Wait()
}

// main downloads a fixed list of files with up to three concurrent transfers
// and prints the elapsed time. When any download fails, the error goes to
// stderr and the process exits with status 1 instead of announcing completion.
func main() {
    tasks := []DownloadTask{
        {URL: "https://example.com/file1.zip", Filename: "downloads/file1.zip"},
        {URL: "https://example.com/file2.zip", Filename: "downloads/file2.zip"},
        {URL: "https://example.com/file3.zip", Filename: "downloads/file3.zip"},
    }

    start := time.Now()
    if err := downloadAll(tasks, 3); err != nil {
        fmt.Fprintln(os.Stderr, "Error:", err)
        os.Exit(1)
    }
    fmt.Printf("All downloads completed in %v\n", time.Since(start))
}

🏢 业务场景

  1. REST API:构建完整的 CRUD 服务,如用户管理、订单系统
  2. CLI 工具:开发运维工具、数据库迁移工具、代码生成器
  3. 微服务:gRPC 服务间通信,服务发现和负载均衡
  4. 数据采集:并发爬虫、API 数据抓取、文件下载器
  5. 消息处理:Kafka/RabbitMQ 消费者、事件驱动架构

📖 扩展阅读