介绍
errgroup 是什么?
errgroup 是Go官方扩展库 golang.org/x/sync 中的并发原语,用于管理多个 goroutine 的协同执行和错误处理。它解决了标准库 sync.WaitGroup 无法便捷传递错误的问题,支持在并发任务中统一处理错误,并支持在首个错误发生时取消所有任务。
核心特性
- 错误聚合与传播
当任意一个 goroutine 返回错误时,errgroup 会立即取消其他未完成的任务,并将错误传递给主 goroutine,避免资源浪费。
- 批量任务取消
通过 context.Context 实现任务组的级联取消,一旦某个任务失败,整个任务组会被终止。
- 超时控制
支持结合 context.WithTimeout 设置任务组超时时间,防止长时间阻塞。
与标准库 sync.WaitGroup 的对比
相较于 sync.WaitGroup 仅提供等待机制,errgroup增加了以下能力:
- 错误自动收集与传播;
- 任务取消的原子性操作;
- 更简洁的并发任务管理代码结构。
使用案例
安装
1
| go get github.com/golang/sync
|
使用标准库的 sync.WaitGroup 收集并发错误
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52
| package main
import ( "fmt" "net/http" "sync" )
func main() { var urls = []string{ "http://www.baidu.org/", "http://www.bilibili.com/", "http://www.somestupidname.com/", }
var errors []error var mu sync.Mutex var wg sync.WaitGroup
for _, url := range urls { wg.Add(1) go func(u string) { defer wg.Done() resp, err := http.Get(u) if err != nil { mu.Lock() errors = append(errors, fmt.Errorf("访问 %s 失败: %v", u, err)) mu.Unlock() return } defer resp.Body.Close() fmt.Printf("访问 %s 成功,状态码: %s\n", u, resp.Status) }(url) }
wg.Wait()
if len(errors) > 0 { fmt.Println("\n收集到以下错误:") for i, err := range errors { fmt.Printf("错误 %d: %v\n", i+1, err) } } else { fmt.Println("\n所有任务执行成功,无错误") } }
|
执行结果:
1 2 3 4 5
| 访问 http://www.bilibili.com/ 成功,状态码: 200 OK 访问 http://www.somestupidname.com/ 成功,状态码: 200 OK
收集到以下错误: 错误 1: 访问 http://www.baidu.org/ 失败: Get "http://www.baidu.org/": EOF
|
这里必须要把所有的并发任务执行完成后,才可以返回所有的 goroutinue,而且还需要声明两个变量 errors 和 mutex 来确保并发安全。
使用 errgroup 来收集并发错误
安装 go get golang.org/x/sync
下面我们看看如果使用了 errgroup 可以怎么做:
1. 常规用法
使用 errgroup 来收集并发错误:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61
| package main
import ( "fmt" "net/http" "sync"
"golang.org/x/sync/errgroup" )
func main() { var urls = []string{ "http://www.baidu.org/", "http://www.bilibili.com/", "http://www.somestupidname.xxyy/", } var allErrors []error var mu sync.Mutex var g errgroup.Group
for _, url := range urls { u := url g.Go(func() error { resp, err := http.Get(u) if err != nil { mu.Lock() allErrors = append(allErrors, fmt.Errorf("访问 %s 失败: %w", u, err)) mu.Unlock() return err } defer resp.Body.Close() fmt.Printf("访问 %s 成功,状态码: %s\n", u, resp.Status) return nil }) }
_ = g.Wait()
if len(allErrors) > 0 { fmt.Println("\n收集到所有错误:") for i, err := range allErrors { fmt.Printf("错误 %d: %v\n", i+1, err) } } else { fmt.Println("\n所有请求均成功,无错误") } }
|
执行结果:
1 2 3 4 5
| 访问 http://www.bilibili.com/ 成功,状态码: 200 OK
收集到所有错误: 错误 1: 访问 http://www.baidu.org/ 失败: Get "http://www.baidu.org/": EOF 错误 2: 访问 http://www.somestupidname.xxyy/ 失败: Get "http://www.somestupidname.xxyy/": EOF
|
2. 取消上下文
如果收集到一个错误后立刻取消其他 goroutinue,避免资源浪费,并在 Wait 方法中返回第一个非 nil 的错误:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81
| package main
import ( "context" "fmt" "net/http" "sync"
"golang.org/x/sync/errgroup" )
func main() { var urls = []string{ "https://www.baidu.org/", "https://www.bilibili.com/", "https://www.somestupidname.xxyy/", }
ctx, cancel := context.WithCancel(context.Background()) defer cancel()
var allErrors []error var mu sync.Mutex
g, ctx := errgroup.WithContext(ctx)
for _, url := range urls { u := url g.Go(func() error { req, err := http.NewRequestWithContext(ctx, "GET", u, nil) if err != nil { mu.Lock() allErrors = append(allErrors, fmt.Errorf("创建请求 %s 失败: %w", u, err)) mu.Unlock() return err }
resp, err := http.DefaultClient.Do(req) if err != nil { mu.Lock() allErrors = append(allErrors, fmt.Errorf("访问 %s 失败: %w", u, err)) mu.Unlock() return err } defer resp.Body.Close()
select { case <-ctx.Done(): mu.Lock() allErrors = append(allErrors, fmt.Errorf("请求 %s 被取消: %w", u, ctx.Err())) mu.Unlock() return ctx.Err() default: fmt.Printf("访问 %s 成功,状态码: %s\n", u, resp.Status) return nil } }) }
if err := g.Wait(); err != nil { fmt.Println("Error: ", err) }
if len(allErrors) > 0 { fmt.Println("\n收集到所有错误:") for i, err := range allErrors { fmt.Printf("错误 %d: %v\n", i+1, err) } } else { fmt.Println("\n所有请求均成功,无错误") } }
|
执行结果:
1 2 3 4 5 6
| Error: Get "https://www.somestupidname.xxyy/": EOF
收集到所有错误: 错误 1: 访问 https://www.somestupidname.xxyy/ 失败: Get "https://www.somestupidname.xxyy/": EOF 错误 2: 访问 https://www.bilibili.com/ 失败: Get "https://www.bilibili.com/": context canceled 错误 3: 访问 https://www.baidu.org/ 失败: Get "https://www.baidu.org/": EOF
|
这里我们可以看到错误 2 的错误原因是上下文被取消造成的。
3. 限制并发数量
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58
| package main
import ( "context" "fmt" "sync" "time"
"golang.org/x/sync/errgroup" )
func main() { tasks := []int{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}
var currentConcurrency int var mu sync.Mutex
g, ctx := errgroup.WithContext(context.Background()) g.SetLimit(2)
for _, task := range tasks { taskID := task g.Go(func() error { mu.Lock() currentConcurrency++ fmt.Printf("任务 %d 开始,当前并发数: %d\n", taskID, currentConcurrency) mu.Unlock()
select { case <-ctx.Done(): return ctx.Err() default: time.Sleep(500 * time.Millisecond) }
mu.Lock() currentConcurrency-- fmt.Printf("任务 %d 结束,当前并发数: %d\n", taskID, currentConcurrency) mu.Unlock()
return nil }) }
if err := g.Wait(); err != nil { fmt.Printf("执行出错: %v\n", err) } else { fmt.Println("所有任务执行完毕") } }
|
执行结果:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21
| 任务 2 开始,当前并发数: 1 任务 1 开始,当前并发数: 2 任务 1 结束,当前并发数: 1 任务 3 开始,当前并发数: 2 任务 2 结束,当前并发数: 1 任务 4 开始,当前并发数: 2 任务 4 结束,当前并发数: 1 任务 5 开始,当前并发数: 2 任务 3 结束,当前并发数: 1 任务 6 开始,当前并发数: 2 任务 6 结束,当前并发数: 1 任务 7 开始,当前并发数: 2 任务 5 结束,当前并发数: 1 任务 8 开始,当前并发数: 2 任务 8 结束,当前并发数: 1 任务 9 开始,当前并发数: 2 任务 7 结束,当前并发数: 1 任务 10 开始,当前并发数: 2 任务 9 结束,当前并发数: 1 任务 10 结束,当前并发数: 0 所有任务执行完毕
|
从执行结果看,并发数始终没有超过 2。
4. 尝试启动
errgroup 还提供了 errgroup.TryGo 可以尝试启动一个任务,它返回一个 bool 值,标识任务是否启动成功,true 表示成功,false 表示失败。
errgroup.TryGo 需要搭配 errgroup.SetLimit 一同使用,因为如果不限制并发数量,那么 errgroup.TryGo 始终返回 true,当达到最大并发数量限制时,errgroup.TryGo 返回 false。
示例如下:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42
| package main
import ( "fmt" "time"
"golang.org/x/sync/errgroup" )
func main() { var g errgroup.Group g.SetLimit(3)
for i := 1; i <= 10; i++ { num := i if g.TryGo(func() error { fmt.Printf("goroutine %d 正在启动\n", num) time.Sleep(2 * time.Second) fmt.Printf("goroutine %d 已完成\n", num) return nil }) { fmt.Printf("goroutine %d 启动成功\n", num) } else { fmt.Printf("goroutine %d 无法启动(已达并发限制)\n", num) } }
if err := g.Wait(); err != nil { fmt.Printf("遇到错误:%v\n", err) }
fmt.Println("所有goroutine已完成。") }
|
执行结果:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17
| goroutine 1 启动成功 goroutine 1 正在启动 goroutine 2 启动成功 goroutine 3 启动成功 goroutine 4 无法启动(已达并发限制) goroutine 5 无法启动(已达并发限制) goroutine 6 无法启动(已达并发限制) goroutine 7 无法启动(已达并发限制) goroutine 8 无法启动(已达并发限制) goroutine 9 无法启动(已达并发限制) goroutine 10 无法启动(已达并发限制) goroutine 3 正在启动 goroutine 2 正在启动 goroutine 2 已完成 goroutine 1 已完成 goroutine 3 已完成 所有goroutine已完成。
|
参考