Golang 基于 Redis 的简单 Session 扩展

虽然 Golang 有很多 Web 框架，但我还是更喜欢用小巧的路由库来扩展 Web 应用，这样比较灵活。不过路由库一般不提供 session 管理，所以这里简单地写了一个基于 Redis 的 session 扩展，代码如下：

package session

import (
    "context"
    "encoding/json"
    "github.com/go-redis/redis"
    "math/rand"
    "net/http"
    "strings"
    "time"
)

// StorageInterface is the minimal key/value backend a Session needs.
// Set stores value under key; Get returns the value stored under key.
// Session uses it to persist and reload its JSON-encoded state.
type StorageInterface interface {
	Set(key string, value []byte) error
	Get(key string) ([]byte, error)
}

// RedisStorage is a StorageInterface backend (see its Set and Get methods)
// that keeps session payloads in Redis.
type RedisStorage struct {
	// cache is the go-redis client every GET/SET goes through.
	cache *redis.Client
	// ttl is the expiry setting that Set turns into a Redis expiration.
	ttl time.Duration
}

// Session is one client's session state. Only the exported fields (Prefix and
// Data) are part of the JSON payload written to storage; sessionId and
// storage are unexported and therefore never serialised.
type Session struct {
	// Prefix is prepended to sessionId to form the storage key.
	Prefix string `json:"prefix"`
	// sessionId identifies the session; MiddleWare takes it from, or writes it
	// to, the session cookie.
	sessionId string
	// storage is the backend the session is loaded from and saved to.
	storage StorageInterface
	// Data holds the session's key/value pairs.
	Data map[string]string `json:"data"`
}

// NewSession builds an empty in-memory Session bound to storage. Its data is
// saved under the key prefix+sessionId; storage is not touched here.
func NewSession(storage StorageInterface, sessionId string, prefix string) *Session {
	return &Session{
		Prefix:    prefix,
		sessionId: sessionId,
		storage:   storage,
		Data:      make(map[string]string),
	}
}

func (s *Session) Get(key string) ([]byte, error) {

    data, err := s.storage.Get(s.Prefix + s.sessionId)
    if err != nil {
        return nil, err
    }
    err = json.Unmarshal(data, s)
    if err != nil {
        return nil, err
    }

    return []byte(s.Data[key]), nil
}

// Set records value under key and then saves the whole session, JSON-encoded,
// to storage under Prefix+sessionId. It returns any marshalling or storage
// error. The in-memory update happens even if saving fails.
func (s *Session) Set(key string, value []byte) error {
	s.Data[key] = string(value)

	payload, err := json.Marshal(s)
	if err != nil {
		return err
	}
	return s.storage.Set(s.Prefix+s.sessionId, payload)
}

// NewRedisStorage returns a Redis-backed StorageInterface that talks to Redis
// through client. defaultTTL is kept as the storage's expiry setting;
// RedisStorage.Set decides how it is converted into a Redis expiration.
func NewRedisStorage(client *redis.Client, defaultTTL time.Duration) *RedisStorage {
	return &RedisStorage{cache: client, ttl: defaultTTL}
}

// Set writes value to Redis under key, with an expiration derived from the
// configured ttl.
//
// ttl has been passed as a bare count of seconds (e.g.
// NewRedisStorage(client, 600)) and was always multiplied by time.Second.
// That made a real duration such as 10*time.Minute overflow int64 (anything
// above about 9.2s does). To support both forms:
//   - a ttl whose magnitude is below one second is treated as a number of
//     seconds (this includes 0, which means no expiration);
//   - a larger ttl is used as the duration it already is.
func (storage *RedisStorage) Set(key string, value []byte) error {
	expiration := storage.ttl
	if expiration > -time.Second && expiration < time.Second {
		// Legacy form: a bare count of seconds.
		expiration *= time.Second
	}
	return storage.cache.Set(key, value, expiration).Err()
}

// Get returns the bytes stored in Redis under key. On any error it returns a
// nil slice together with that error, rather than a half-populated value.
// This includes a missing key, which go-redis reports as redis.Nil.
func (storage *RedisStorage) Get(key string) ([]byte, error) {
	data, err := storage.cache.Get(key).Result()
	if err != nil {
		return nil, err
	}
	return []byte(data), nil
}

// sessionContextKey is the private context key under which the *Session is
// stored. An unexported type, instead of the plain string "session", cannot
// collide with keys set by any other package (go vet / staticcheck SA1029).
type sessionContextKey struct{}

// SetSessionContext injects sess into ctx and returns the derived context.
// GetSessionFromRequest reads it back.
func SetSessionContext(ctx context.Context, sess *Session) context.Context {
	return context.WithValue(ctx, sessionContextKey{}, sess)
}

// GetSessionFromRequest returns the *Session that SetSessionContext (called
// by MiddleWare) attached to r's context. It returns nil, instead of
// panicking on a failed type assertion, when no session is present (for
// example when the handler is not wrapped by MiddleWare).
func GetSessionFromRequest(r *http.Request) *Session {
	sess, _ := r.Context().Value(sessionContextKey{}).(*Session)
	return sess
}

func RandomString(length int) string {
    rand.Seed(time.Now().UnixNano())
    chars := []rune("abcdefghijklmnopqrstuvwxyz0123456789")
    var b strings.Builder
    for i := 0; i < length; i++ {
        b.WriteRune(chars[rand.Intn(len(chars))])
    }
    return b.String()
}

// UpdateExpire reads the value stored under prefix+sessionId and writes it
// back unchanged, so the storage's Set applies its TTL again (RedisStorage.Set
// sends the expiration with every SET). It uses the sessionId and prefix
// arguments, not s.sessionId / s.Prefix, and returns the first Get or Set
// error.
//
// NOTE(review): the read and the write are two separate storage calls. A
// Session.Set from another request that lands between them is overwritten
// with the older payload.
func (s *Session) UpdateExpire(sessionId string, prefix string) error {
	storageKey := prefix + sessionId

	current, err := s.storage.Get(storageKey)
	if err != nil {
		return err
	}
	return s.storage.Set(storageKey, current)
}

// Config controls MiddleWare.
//   - Prefix is the storage-key prefix.
//   - Name is the cookie that is both read from requests and set on new
//     sessions.
//   - The remaining fields are copied onto every newly issued session cookie.
type Config struct {
	Prefix   string
	Name     string
	Path     string
	Domain   string
	Expires  time.Time // absolute instant, copied verbatim into each new cookie
	Secure   bool
	HttpOnly bool
	SameSite http.SameSite
}

// MiddleWare returns HTTP middleware that attaches a *Session to every
// request's context. Handlers retrieve it with GetSessionFromRequest.
//
// If the request carries a non-empty cookie named config.Name, its value is
// used as the session ID and the stored session's expiry is refreshed.
// Otherwise a new 32-character ID is generated and sent back in a cookie
// built from config. An empty cookie value counts as missing; accepting it
// would make every such client share the session stored under the bare
// prefix.
//
// NOTE(review): a non-empty ID supplied by the client is accepted even if
// storage has no record of it, so a client (or an attacker who plants the
// cookie) can choose its own session ID. Rejecting unknown IDs needs a way to
// tell "not found" apart from other storage errors, which StorageInterface
// does not provide.
//
// config.Expires is an absolute time copied verbatim into every new cookie.
// Once that instant has passed, newly issued cookies are already expired.
func MiddleWare(storage StorageInterface, config *Config) func(next http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			var session *Session
			if cookie, err := r.Cookie(config.Name); err == nil && cookie.Value != "" {
				session = NewSession(storage, cookie.Value, config.Prefix)
				// Best-effort TTL refresh. It fails when nothing has been
				// stored yet for this ID, which must not fail the request.
				_ = session.UpdateExpire(cookie.Value, config.Prefix)
			} else {
				// No usable cookie: start a new session.
				sessionID := RandomString(32)
				session = NewSession(storage, sessionID, config.Prefix)
				http.SetCookie(w, newSessionCookie(config, sessionID))
			}
			next.ServeHTTP(w, r.WithContext(SetSessionContext(r.Context(), session)))
		})
	}
}

// newSessionCookie builds the cookie that carries a freshly generated session
// ID, using the cookie attributes from config.
func newSessionCookie(config *Config, sessionID string) *http.Cookie {
	return &http.Cookie{
		Name:     config.Name,
		Value:    sessionID,
		Path:     config.Path,
		Domain:   config.Domain,
		Expires:  config.Expires,
		Secure:   config.Secure,
		HttpOnly: config.HttpOnly,
		SameSite: config.SameSite,
	}
}

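这里只实现了 Redis 存储。由于存储层通过 StorageInterface（只需实现 Set 和 Get 两个方法）抽象出来，也可以很容易地接入 Memcached 等其他后端。使用方法如下（注意：示例中的 `./session` 是相对导入路径，只能在 GOPATH 模式下使用；使用 Go Modules 时需改为完整的模块路径）：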

package main

import (
    "./session"
    "fmt"
    "github.com/go-redis/redis"
    "github.com/gorilla/mux"
    "log"
    "net/http"
    "time"
)

// HomeHandler demonstrates the session API. It stores "twn39" under "user",
// reads it back and writes it to the response. Session failures are logged
// and reported as 500 Internal Server Error instead of panicking the handler.
func HomeHandler(w http.ResponseWriter, r *http.Request) {
	sess := session.GetSessionFromRequest(r)
	if sess == nil {
		http.Error(w, "session unavailable", http.StatusInternalServerError)
		return
	}
	if err := sess.Set("user", []byte("twn39")); err != nil {
		log.Printf("saving session: %v", err)
		http.Error(w, "session error", http.StatusInternalServerError)
		return
	}
	name, err := sess.Get("user")
	if err != nil {
		log.Printf("loading session: %v", err)
		http.Error(w, "session error", http.StatusInternalServerError)
		return
	}
	// A write error means the client has gone away; nothing useful to do.
	_, _ = w.Write(name)
}

// main wires a Redis-backed session middleware in front of a gorilla/mux
// router and serves on :8080.
func main() {
	client := redis.NewClient(&redis.Options{
		Addr:     "localhost:6379",
		Password: "", // no password set
		DB:       0,  // use default DB
	})

	config := &session.Config{
		Prefix: "",
		Name:   "GO_SESSION",
		Path:   "/",
		Domain: "",
		// Expires is left zero, so the cookie lasts for the browser session.
		// A time.Now().Add(...) computed here would be one absolute instant
		// shared by every cookie issued during the process lifetime, so after
		// 10 minutes of uptime new cookies would arrive already expired.
		// Server-side lifetime is bounded by the storage TTL below.
		Secure:   false,
		HttpOnly: true,
		SameSite: 0,
	}
	fmt.Printf("config: %v\n", config)

	// 600 is interpreted by RedisStorage.Set as a number of seconds.
	storage := session.NewRedisStorage(client, 600)
	route := mux.NewRouter()
	route.HandleFunc("/", HomeHandler)
	route.Use(session.MiddleWare(storage, config))

	// Explicit timeouts stop slow or idle clients from holding connections
	// open indefinitely.
	srv := &http.Server{
		Addr:              ":8080",
		Handler:           route,
		ReadHeaderTimeout: 5 * time.Second,
		ReadTimeout:       10 * time.Second,
		WriteTimeout:      10 * time.Second,
		IdleTimeout:       60 * time.Second,
	}
	log.Fatal(srv.ListenAndServe())
}