一 什么是singleFlight
singleflight 主要是用来做并发控制,例如高并发场景下,N个请求同时查询一个redis key,如果能将这N个请求合并成一个redis查询,那么性能一定会提高很多。
常见的场景比如防止缓存击穿。
二 go 官方扩展包singleFlight
1.结构体
// Group represents a class of work and forms a namespace in
// which units of work can be executed with duplicate suppression.
type Group struct {
mu sync.Mutex // 保护变量m
m map[string]*call // 惰性初始化
}
// call is an in-flight or completed singleflight.Do call
type call struct {
wg sync.WaitGroup
// val err记录自定义fn函数执行结果,在 wg.Done前写入一次,wg.Done后只会读取val 和err
// These fields are written once before the WaitGroup is done
// and are only read after the WaitGroup is done.
val interface{}
err error
// These fields are read and written with the singleflight
// mutex held before the WaitGroup is done, and are read but
// not written after the WaitGroup is done.
dups int
chans []chan<- Result
}
// Result holds the results of Do, so they can be passed
// on a channel.
type Result struct {
Val interface{}
Err error
Shared bool
}
2.singleflight对外提供了3个方法
a) Do
key用于标识请求,fn()为调用者需要实现的业务逻辑; 返回参数有三个,v和err为fn()的返回值,shared表示返回结果是否是共享的。
在对同一个key多次调用时,Do 确保了fn()只会执行一次,若第一次的调用没有完成,其他调用会阻塞并等待首次调用返回。
// Do executes and returns the results of the given function, making
// sure that only one execution is in-flight for a given key at a
// time. If a duplicate comes in, the duplicate caller waits for the
// original to complete and receives the same results.
// The return value shared indicates whether v was given to multiple callers.
func (g *Group) Do(key string, fn func() (interface{}, error)) (v interface{}, err error, shared bool) {
g.mu.Lock()
if g.m == nil {
g.m = make(map[string]*call)
}
//如果key 已经存在,说明已经有fn在执行或者已经执行完成了,第二次的及以后的后续
if c, ok := g.m[key]; ok {
c.dups++
g.mu.Unlock()
//如果是有fn请求正在执行,则堵塞等待,(doCall函数中 fn执行完成后 c.wg.Done()
c.wg.Wait()
if e, ok := c.err.(*panicError); ok {
panic(e)
} else if c.err == errGoexit {
runtime.Goexit()
}
//返回结果
return c.val, c.err, true
}
//第一次请求
c := new(call)
c.wg.Add(1)
g.m[key] = c
g.mu.Unlock()
g.doCall(c, key, fn)
return c.val, c.err, c.dups > 0
}
b) DoChan()
DoChan()
和Do()和
区别是DoChan()
属于异步调用,返回一个channel,解决同步调用时的阻塞问题;
// DoChan is like Do but returns a channel that will receive the
// results when they are ready.
//
// The returned channel will not be closed.
func (g *Group) DoChan(key string, fn func() (interface{}, error)) <-chan Result {
ch := make(chan Result, 1)
g.mu.Lock()
if g.m == nil {
g.m = make(map[string]*call)
}
//当已经有fn在执行时,直接返回channel,fn执行完成后,会将结果写入channel
if c, ok := g.m[key]; ok {
c.dups++
c.chans = append(c.chans, ch)
g.mu.Unlock()
return ch
}
c := &call{chans: []chan<- Result{ch}}
c.wg.Add(1)
g.m[key] = c
g.mu.Unlock()
go g.doCall(c, key, fn)
return ch
}
c) doCall()
doCall函数负责执行 自定义的fn函数。
// doCall handles the single call for a key.
func (g *Group) doCall(c *call, key string, fn func() (interface{}, error)) {
normalReturn := false
recovered := false
// use double-defer to distinguish panic from runtime.Goexit,
// more details see https://golang.org/cl/134395
defer func() {
// the given function invoked runtime.Goexit
if !normalReturn && !recovered {
c.err = errGoexit
}
g.mu.Lock()
defer g.mu.Unlock()
c.wg.Done()
if g.m[key] == c {
delete(g.m, key)
}
if e, ok := c.err.(*panicError); ok {
// In order to prevent the waiting channels from being blocked forever,
// needs to ensure that this panic cannot be recovered.
if len(c.chans) > 0 {
go panic(e)
select {} // Keep this goroutine around so that it will appear in the crash dump.
} else {
panic(e)
}
} else if c.err == errGoexit {
// Already in the process of goexit, no need to call again
} else {
// Normal return ,如果chans非空,将执行结果写入所有channel。()
for _, ch := range c.chans {
ch <- Result{c.val, c.err, c.dups > 0}
}
}
}()
func() {
defer func() {
if !normalReturn {
// Ideally, we would wait to take a stack trace until we've determined
// whether this is a panic or a runtime.Goexit.
//
// Unfortunately, the only way we can distinguish the two is to see
// whether the recover stopped the goroutine from terminating, and by
// the time we know that, the part of the stack trace relevant to the
// panic has been discarded.
if r := recover(); r != nil {
c.err = newPanicError(r)
}
}
}()
//写入fn的执行结果
c.val, c.err = fn()
normalReturn = true
}()
if !normalReturn {
recovered = true
}
}
总结:
-
singleflight使用sync.Mutex和sync.WaitGroup进行并发控制
-
对于key相同的请求, singleflight只会处理的一个进入的请求,后续的请求都使用waitGroup.wait()将请求阻塞
-
使用双重defer()区分了panic和runtime.Goexit错误,如果返回的是一个panic错误,group.c.chans会发生阻塞,那么需要抛出这个panic且确保其无法被recover
三 singleFlight使用示例
package main
import (
"errors"
"fmt"
"golang.org/x/net/context"
"golang.org/x/sync/singleflight"
_ "net/http/pprof"
"sync"
"sync/atomic"
"time"
)
var count int32
func TestSingleFlight() {
var (
wg sync.WaitGroup
now = time.Now()
n = 1000
sg = &singleflight.Group{}
)
atomic.StoreInt32(&count, 0)
//循环查询1000次
for i := 0; i < n; i++ {
wg.Add(1)
go func() {
res, _ := getArticle(1)
if res != "article: 1" {
panic("err")
}
wg.Done()
}()
}
wg.Wait()
fmt.Printf("同时发起 %d 次请求,耗时: %s", count, time.Since(now))
fmt.Println("----------------------")
atomic.StoreInt32(&count, 0)
now = time.Now()
for i := 0; i < n; i++ {
wg.Add(1)
go func() {
res, _ := singleFlightGetArticle(sg, 1)
//res, _ := getArticle(1)
if res != "article: 1" {
panic("err")
}
wg.Done()
}()
}
wg.Wait()
fmt.Printf("single Flight do同时发起 %d 次请求,耗时: %s", count, time.Since(now))
fmt.Println("------------")
//DoChan
atomic.StoreInt32(&count, 0)
now = time.Now()
for i := 0; i < n; i++ {
wg.Add(1)
go func() {
res, _ := singleFlightGetArticleChannel(sg, 1)
//res, _ := getArticle(1)
if res != "article: 1" {
panic("err")
}
wg.Done()
}()
}
wg.Wait()
fmt.Printf("singleFlight DoChan同时发起%d次请求,耗时: %s", count, time.Since(now))
}
func singleFlightGetArticle(sg *singleflight.Group, id int) (string, error) {
v, err, _ := sg.Do(fmt.Sprintf("%d", id), func() (interface{}, error) {
return getArticle(id)
})
return v.(string), err
}
/*
通过do channel, context 超时处理
*/
func singleFlightGetArticleChannel(sg *singleflight.Group, id int) (string, error) {
channel := sg.DoChan(fmt.Sprintf("%d", id), func() (interface{}, error) {
return getArticle(id)
})
ctx, _ := context.WithTimeout(context.Background(), 4*time.Second)
for { // 选择 context + select 超时控制
select {
case <-ctx.Done():
fmt.Println("singleFlightGetArticleChannel time out")
return "", errors.New("ctx-timeout") // 根据业务逻辑选择上抛 error
case data, _ := <-channel:
return data.Val.(string), nil
default:
fmt.Println("singleFlightGetArticleChannel default")
}
}
}
/*
模拟获取文章详情
*/
func getArticle(id int) (article string, err error) {
atomic.AddInt32(&count, 1)
//time.Sleep(time.Duration(count) * time.Millisecond)
fmt.Printf("getArticle run times %d", count)
return fmt.Sprintf("article: %d", id), nil
}