Go语言学习之对象关系映射GORM

张开发
2026/4/7 19:24:03 15 分钟阅读

分享文章

Go语言学习之对象关系映射GORM
在 Gin 框架中集成 ORM，GORM 是绝对的主流选择——成熟、功能完善、生态丰富。我来给你一个从零到生产的完整方案。

## 技术栈选择

| ORM | 特点 | 适用场景 |
| --- | --- | --- |
| GORM | 功能最全、生态最好、Auto Migration | 通用项目、快速开发 |
| Ent | Facebook 出品、代码生成、类型安全 | 大型项目、强类型需求 |
| sqlx | 轻量、原生 SQL、无魔法 | 性能敏感、SQL 掌控 |
| sqlc | 编译期生成、类型安全 | 追求性能和类型安全 |

本篇以 GORM 为核心，覆盖 90% 的业务场景。

## 项目结构

```
project/
├── main.go
├── config/
│   └── config.go
├── models/
│   └── user.go
├── controllers/
│   └── user_controller.go
├── repositories/
│   └── user_repository.go
├── services/
│   └── user_service.go
├── middlewares/
│   └── auth.go
└── database/
    └── database.go
```

## 数据库连接与配置

```go
// database/database.go
package database

import (
	"fmt"
	"log"
	"time"

	"gorm.io/driver/mysql"
	"gorm.io/driver/postgres"
	"gorm.io/driver/sqlite"
	"gorm.io/gorm"
	"gorm.io/gorm/logger"
)

type Config struct {
	Driver          string
	DSN             string
	MaxIdleConns    int
	MaxOpenConns    int
	ConnMaxLifetime time.Duration
	LogLevel        logger.LogLevel
}

func NewDB(cfg Config) (*gorm.DB, error) {
	var (
		db  *gorm.DB
		err error
	)

	// 根据驱动选择数据库
	switch cfg.Driver {
	case "mysql":
		db, err = gorm.Open(mysql.Open(cfg.DSN), &gorm.Config{
			Logger: logger.Default.LogMode(cfg.LogLevel),
		})
	case "postgres":
		db, err = gorm.Open(postgres.Open(cfg.DSN), &gorm.Config{
			Logger: logger.Default.LogMode(cfg.LogLevel),
		})
	case "sqlite":
		db, err = gorm.Open(sqlite.Open(cfg.DSN), &gorm.Config{
			Logger: logger.Default.LogMode(cfg.LogLevel),
		})
	default:
		return nil, fmt.Errorf("unsupported driver: %s", cfg.Driver)
	}

	if err != nil {
		return nil, fmt.Errorf("failed to connect database: %w", err)
	}

	// 连接池配置
	sqlDB, err := db.DB()
	if err != nil {
		return nil, fmt.Errorf("failed to get sql.DB: %w", err)
	}

	sqlDB.SetMaxIdleConns(cfg.MaxIdleConns)
	sqlDB.SetMaxOpenConns(cfg.MaxOpenConns)
	sqlDB.SetConnMaxLifetime(cfg.ConnMaxLifetime)

	// 测试连接
	if err := sqlDB.Ping(); err != nil {
		return nil, fmt.Errorf("failed to ping database: %w", err)
	}

	log.Println("Database connected successfully")
	return db, nil
}
```

## 模型定义

```go
// models/user.go
package models

import (
	"time"

	"gorm.io/gorm"
)

type User struct {
	ID        uint           `gorm:"primaryKey" json:"id"`
	Username  string         `gorm:"size:50;uniqueIndex;not null" json:"username"`
	Email     string         `gorm:"size:100;uniqueIndex;not null" json:"email"`
	Password  string         `gorm:"size:255;not null" json:"-"` // 不暴露到 JSON
	Nickname  string         `gorm:"size:50" json:"nickname"`
	Avatar    string         `gorm:"size:255" json:"avatar"`
	Status    int8           `gorm:"default:1;index" json:"status"` // 1:active 2:inactive 3:banned
	Role      string         `gorm:"size:20;default:user" json:"role"`
	LastLogin *time.Time     `json:"last_login"`
	CreatedAt time.Time      `json:"created_at"`
	UpdatedAt time.Time      `json:"updated_at"`
	DeletedAt gorm.DeletedAt `gorm:"index" json:"-"` // 软删除

	// 关联关系
	Posts []Post `gorm:"foreignKey:AuthorID" json:"posts,omitempty"`
}

func (User) TableName() string {
	return "users"
}

type Post struct {
	ID        uint           `gorm:"primaryKey" json:"id"`
	Title     string         `gorm:"size:200;not null" json:"title"`
	Content   string         `gorm:"type:text" json:"content"`
	AuthorID  uint           `gorm:"not null;index" json:"author_id"`
	Status    int8           `gorm:"default:1;index" json:"status"` // 1:draft 2:published
	ViewCount int            `gorm:"default:0" json:"view_count"`
	CreatedAt time.Time      `json:"created_at"`
	UpdatedAt time.Time      `json:"updated_at"`
	DeletedAt gorm.DeletedAt `gorm:"index" json:"-"`

	// 关联关系
	Author User  `gorm:"foreignKey:AuthorID" json:"author,omitempty"`
	Tags   []Tag `gorm:"many2many:post_tags;" json:"tags,omitempty"`
}

func (Post) TableName() string {
	return "posts"
}

type Tag struct {
	ID        uint      `gorm:"primaryKey" json:"id"`
	Name      string    `gorm:"size:50;uniqueIndex;not null" json:"name"`
	CreatedAt time.Time `json:"created_at"`
	Posts     []Post    `gorm:"many2many:post_tags;" json:"posts,omitempty"`
}

func (Tag) TableName() string {
	return "tags"
}
```

## Repository 层（数据访问）

```go
// repositories/user_repository.go
package repositories

import (
	"errors"
	"time"

	"gorm.io/gorm"

	"your-project/models"
)

type UserRepository struct {
	db *gorm.DB
}

func NewUserRepository(db *gorm.DB) *UserRepository {
	return &UserRepository{db: db}
}

// Create 创建用户
func (r *UserRepository) Create(user *models.User) error {
	return r.db.Create(user).Error
}

// FindByID 根据 ID 查询
func (r *UserRepository) FindByID(id uint) (*models.User, error) {
	var user models.User
	err := r.db.First(&user, id).Error
	if err != nil {
		return nil, err
	}
	return &user, nil
}

// FindByUsername 根据用户名查询
func (r *UserRepository) FindByUsername(username string) (*models.User, error) {
	var user models.User
	err := r.db.Where("username = ?", username).First(&user).Error
	if err != nil {
		return nil, err
	}
	return &user, nil
}

// FindByEmail 根据邮箱查询
func (r *UserRepository) FindByEmail(email string) (*models.User, error) {
	var user models.User
	err := r.db.Where("email = ?", email).First(&user).Error
	if err != nil {
		return nil, err
	}
	return &user, nil
}

// FindAll 分页查询
func (r *UserRepository) FindAll(page, pageSize int) ([]models.User, int64, error) {
	var users []models.User
	var total int64

	// 计算总数
	if err := r.db.Model(&models.User{}).Count(&total).Error; err != nil {
		return nil, 0, err
	}

	// 分页查询
	offset := (page - 1) * pageSize
	err := r.db.Offset(offset).Limit(pageSize).Find(&users).Error
	if err != nil {
		return nil, 0, err
	}

	return users, total, nil
}

// Update 更新用户
func (r *UserRepository) Update(user *models.User) error {
	return r.db.Save(user).Error
}

// UpdateFields 更新指定字段
func (r *UserRepository) UpdateFields(id uint, fields map[string]interface{}) error {
	return r.db.Model(&models.User{}).Where("id = ?", id).Updates(fields).Error
}

// UpdateLastLogin 更新最后登录时间
func (r *UserRepository) UpdateLastLogin(id uint) error {
	now := time.Now()
	return r.UpdateFields(id, map[string]interface{}{
		"last_login": now,
	})
}

// Delete 软删除
func (r *UserRepository) Delete(id uint) error {
	return r.db.Delete(&models.User{}, id).Error
}

// HardDelete 硬删除
func (r *UserRepository) HardDelete(id uint) error {
	return r.db.Unscoped().Delete(&models.User{}, id).Error
}

// ExistsByUsername 检查用户名是否存在
func (r *UserRepository) ExistsByUsername(username string) (bool, error) {
	var count int64
	err := r.db.Model(&models.User{}).Where("username = ?", username).Count(&count).Error
	return count > 0, err
}

// ExistsByEmail 检查邮箱是否存在
func (r *UserRepository) ExistsByEmail(email string) (bool, error) {
	var count int64
	err := r.db.Model(&models.User{}).Where("email = ?", email).Count(&count).Error
	return count > 0, err
}

// FindWithPosts 查询用户及其文章（预加载）
func (r *UserRepository) FindWithPosts(id uint) (*models.User, error) {
	var user models.User
	err := r.db.Preload("Posts").First(&user, id).Error
	if err != nil {
		return nil, err
	}
	return &user, nil
}

// FindActiveUsers 查询活跃用户
func (r *UserRepository) FindActiveUsers(limit int) ([]models.User, error) {
	var users []models.User
	err := r.db.Where("status = ?", 1).Limit(limit).Find(&users).Error
	return users, err
}
```

## Service 层（业务逻辑）

```go
// services/user_service.go
package services

import (
	"errors"
	"time"

	"golang.org/x/crypto/bcrypt"

	"your-project/models"
	"your-project/repositories"
)

var (
	ErrUserNotFound       = errors.New("user not found")
	ErrUsernameExists     = errors.New("username already exists")
	ErrEmailExists        = errors.New("email already exists")
	ErrInvalidCredentials = errors.New("invalid credentials")
)

type UserService struct {
	repo *repositories.UserRepository
}

func NewUserService(repo *repositories.UserRepository) *UserService {
	return &UserService{repo: repo}
}

type CreateUserInput struct {
	Username string `json:"username" binding:"required,min=3,max=50"`
	Email    string `json:"email" binding:"required,email"`
	Password string `json:"password" binding:"required,min=6"`
	Nickname string `json:"nickname"`
}

type UpdateUserInput struct {
	Nickname string `json:"nickname"`
	Avatar   string `json:"avatar"`
}

type LoginInput struct {
	Email    string `json:"email" binding:"required,email"`
	Password string `json:"password" binding:"required"`
}

type UserResponse struct {
	ID        uint       `json:"id"`
	Username  string     `json:"username"`
	Email     string     `json:"email"`
	Nickname  string     `json:"nickname"`
	Avatar    string     `json:"avatar"`
	Status    int8       `json:"status"`
	Role      string     `json:"role"`
	LastLogin *time.Time `json:"last_login"`
	CreatedAt time.Time  `json:"created_at"`
}

// Create 创建用户
func (s *UserService) Create(input CreateUserInput) (*UserResponse, error) {
	// 检查用户名是否存在
	if exists, _ := s.repo.ExistsByUsername(input.Username); exists {
		return nil, ErrUsernameExists
	}

	// 检查邮箱是否存在
	if exists, _ := s.repo.ExistsByEmail(input.Email); exists {
		return nil, ErrEmailExists
	}

	// 密码加密
	hashedPassword, err := bcrypt.GenerateFromPassword([]byte(input.Password), bcrypt.DefaultCost)
	if err != nil {
		return nil, err
	}

	user := &models.User{
		Username: input.Username,
		Email:    input.Email,
		Password: string(hashedPassword),
		Nickname: input.Nickname,
		Status:   1,
		Role:     "user",
	}

	if err := s.repo.Create(user); err != nil {
		return nil, err
	}

	return s.toResponse(user), nil
}

// Login 用户登录
func (s *UserService) Login(input LoginInput) (*UserResponse, error) {
	user, err := s.repo.FindByEmail(input.Email)
	if err != nil {
		return nil, ErrInvalidCredentials
	}

	// 验证密码
	if err := bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(input.Password)); err != nil {
		return nil, ErrInvalidCredentials
	}

	// 更新最后登录时间
	_ = s.repo.UpdateLastLogin(user.ID)

	return s.toResponse(user), nil
}

// GetByID 根据 ID 获取用户
func (s *UserService) GetByID(id uint) (*UserResponse, error) {
	user, err := s.repo.FindByID(id)
	if err != nil {
		return nil, ErrUserNotFound
	}
	return s.toResponse(user), nil
}

// List 获取用户列表
func (s *UserService) List(page, pageSize int) ([]UserResponse, int64, error) {
	users, total, err := s.repo.FindAll(page, pageSize)
	if err != nil {
		return nil, 0, err
	}

	responses := make([]UserResponse, len(users))
	for i, user := range users {
		responses[i] = *s.toResponse(&user)
	}

	return responses, total, nil
}

// Update 更新用户
func (s *UserService) Update(id uint, input UpdateUserInput) (*UserResponse, error) {
	user, err := s.repo.FindByID(id)
	if err != nil {
		return nil, ErrUserNotFound
	}

	if input.Nickname != "" {
		user.Nickname = input.Nickname
	}
	if input.Avatar != "" {
		user.Avatar = input.Avatar
	}

	if err := s.repo.Update(user); err != nil {
		return nil, err
	}

	return s.toResponse(user), nil
}

// Delete 删除用户
func (s *UserService) Delete(id uint) error {
	return s.repo.Delete(id)
}

func (s *UserService) toResponse(user *models.User) *UserResponse {
	return &UserResponse{
		ID:        user.ID,
		Username:  user.Username,
		Email:     user.Email,
		Nickname:  user.Nickname,
		Avatar:    user.Avatar,
		Status:    user.Status,
		Role:      user.Role,
		LastLogin: user.LastLogin,
		CreatedAt: user.CreatedAt,
	}
}
```

## Controller 层

```go
// controllers/user_controller.go
package controllers

import (
	"net/http"
	"strconv"

	"github.com/gin-gonic/gin"

	"your-project/services"
)

type UserController struct {
	service *services.UserService
}

func NewUserController(service *services.UserService) *UserController {
	return &UserController{service: service}
}

// Register 用户注册
func (c *UserController) Register(ctx *gin.Context) {
	var input services.CreateUserInput
	if err := ctx.ShouldBindJSON(&input); err != nil {
		ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
		return
	}

	user, err := c.service.Create(input)
	if err != nil {
		ctx.JSON(http.StatusConflict, gin.H{"error": err.Error()})
		return
	}

	ctx.JSON(http.StatusCreated, gin.H{
		"message": "user created successfully",
		"data":    user,
	})
}

// Login 用户登录
func (c *UserController) Login(ctx *gin.Context) {
	var input services.LoginInput
	if err := ctx.ShouldBindJSON(&input); err != nil {
		ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
		return
	}

	user, err := c.service.Login(input)
	if err != nil {
		ctx.JSON(http.StatusUnauthorized, gin.H{"error": err.Error()})
		return
	}

	// 生成 JWT Token（示例）
	// token, _ := generateJWT(user.ID)

	ctx.JSON(http.StatusOK, gin.H{
		"message": "login successful",
		"data":    user,
		// "token": token,
	})
}

// GetProfile 获取用户信息
func (c *UserController) GetProfile(ctx *gin.Context) {
	// 从中间件获取用户 ID
	userID, exists := ctx.Get("user_id")
	if !exists {
		ctx.JSON(http.StatusUnauthorized, gin.H{"error": "unauthorized"})
		return
	}

	user, err := c.service.GetByID(userID.(uint))
	if err != nil {
		ctx.JSON(http.StatusNotFound, gin.H{"error": err.Error()})
		return
	}

	ctx.JSON(http.StatusOK, gin.H{"data": user})
}

// GetUser 获取单个用户
func (c *UserController) GetUser(ctx *gin.Context) {
	id, err := strconv.ParseUint(ctx.Param("id"), 10, 32)
	if err != nil {
		ctx.JSON(http.StatusBadRequest, gin.H{"error": "invalid user id"})
		return
	}

	user, err := c.service.GetByID(uint(id))
	if err != nil {
		ctx.JSON(http.StatusNotFound, gin.H{"error": err.Error()})
		return
	}

	ctx.JSON(http.StatusOK, gin.H{"data": user})
}

// ListUsers 获取用户列表
func (c *UserController) ListUsers(ctx *gin.Context) {
	page, _ := strconv.Atoi(ctx.DefaultQuery("page", "1"))
	pageSize, _ := strconv.Atoi(ctx.DefaultQuery("page_size", "10"))

	if page < 1 {
		page = 1
	}
	if pageSize < 1 || pageSize > 100 {
		pageSize = 10
	}

	users, total, err := c.service.List(page, pageSize)
	if err != nil {
		ctx.JSON(http.StatusInternalServerError, gin.H{"error": "failed to get users"})
		return
	}

	ctx.JSON(http.StatusOK, gin.H{
		"data": users,
		"meta": gin.H{
			"page":      page,
			"page_size": pageSize,
			"total":     total,
		},
	})
}

// UpdateUser 更新用户
func (c *UserController) UpdateUser(ctx *gin.Context) {
	id, err := strconv.ParseUint(ctx.Param("id"), 10, 32)
	if err != nil {
		ctx.JSON(http.StatusBadRequest, gin.H{"error": "invalid user id"})
		return
	}

	var input services.UpdateUserInput
	if err := ctx.ShouldBindJSON(&input); err != nil {
		ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
		return
	}

	user, err := c.service.Update(uint(id), input)
	if err != nil {
		ctx.JSON(http.StatusNotFound, gin.H{"error": err.Error()})
		return
	}

	ctx.JSON(http.StatusOK, gin.H{
		"message": "user updated successfully",
		"data":    user,
	})
}

// DeleteUser 删除用户
func (c *UserController) DeleteUser(ctx *gin.Context) {
	id, err := strconv.ParseUint(ctx.Param("id"), 10, 32)
	if err != nil {
		ctx.JSON(http.StatusBadRequest, gin.H{"error": "invalid user id"})
		return
	}

	if err := c.service.Delete(uint(id)); err != nil {
		ctx.JSON(http.StatusNotFound, gin.H{"error": err.Error()})
		return
	}

	ctx.JSON(http.StatusOK, gin.H{"message": "user deleted successfully"})
}
```

## 主程序集成

```go
// main.go
package main

import (
	"log"

	"github.com/gin-gonic/gin"
	"gorm.io/gorm"

	"your-project/config"
	"your-project/controllers"
	"your-project/database"
	"your-project/middlewares"
	"your-project/models"
	"your-project/repositories"
	"your-project/services"
)

func main() {
	// 加载配置
	cfg := config.Load()

	// 初始化数据库
	db, err := database.NewDB(database.Config{
		Driver:          cfg.DB.Driver,
		DSN:             cfg.DB.DSN,
		MaxIdleConns:    cfg.DB.MaxIdleConns,
		MaxOpenConns:    cfg.DB.MaxOpenConns,
		ConnMaxLifetime: cfg.DB.ConnMaxLifetime,
		LogLevel:        gorm.LogLevel(cfg.DB.LogLevel),
	})
	if err != nil {
		log.Fatalf("Failed to connect database: %v", err)
	}

	// 自动迁移
	if err := autoMigrate(db); err != nil {
		log.Fatalf("Failed to migrate database: %v", err)
	}

	// 初始化依赖
	userRepo := repositories.NewUserRepository(db)
	userService := services.NewUserService(userRepo)
	userController := controllers.NewUserController(userService)

	// 初始化 Gin
	r := gin.Default()

	// 注册路由
	setupRoutes(r, userController)

	// 启动服务
	log.Printf("Server starting on :%s", cfg.Server.Port)
	if err := r.Run(":" + cfg.Server.Port); err != nil {
		log.Fatalf("Failed to start server: %v", err)
	}
}

func autoMigrate(db *gorm.DB) error {
	return db.AutoMigrate(
		&models.User{},
		&models.Post{},
		&models.Tag{},
	)
}

func setupRoutes(r *gin.Engine, userCtrl *controllers.UserController) {
	// 公开路由
	public := r.Group("/api/v1")
	{
		public.POST("/register", userCtrl.Register)
		public.POST("/login", userCtrl.Login)
		public.GET("/users", userCtrl.ListUsers)
		public.GET("/users/:id", userCtrl.GetUser)
	}

	// 需要认证的路由
	protected := r.Group("/api/v1")
	protected.Use(middlewares.AuthMiddleware())
	{
		protected.GET("/profile", userCtrl.GetProfile)
		protected.PUT("/users/:id", userCtrl.UpdateUser)
		protected.DELETE("/users/:id", userCtrl.DeleteUser)
	}
}
```

## 高级查询示例

```go
// repositories/post_repository.go
package repositories

import (
	"gorm.io/gorm"

	"your-project/models"
)

type PostRepository struct {
	db *gorm.DB
}

func NewPostRepository(db *gorm.DB) *PostRepository {
	return &PostRepository{db: db}
}

// FindPublished 查询已发布文章（分页）
func (r *PostRepository) FindPublished(page, pageSize int) ([]models.Post, int64, error) {
	var posts []models.Post
	var total int64

	query := r.db.Model(&models.Post{}).Where("status = ?", 2) // 2: published

	if err := query.Count(&total).Error; err != nil {
		return nil, 0, err
	}

	offset := (page - 1) * pageSize
	err := query.Preload("Author").
		Preload("Tags").
		Order("created_at DESC").
		Offset(offset).
		Limit(pageSize).
		Find(&posts).Error

	return posts, total, err
}

// Search 搜索文章
func (r *PostRepository) Search(keyword string, page, pageSize int) ([]models.Post, int64, error) {
	var posts []models.Post
	var total int64

	query := r.db.Model(&models.Post{}).
		Where("status = ?", 2).
		Where("title LIKE ? OR content LIKE ?", "%"+keyword+"%", "%"+keyword+"%")

	if err := query.Count(&total).Error; err != nil {
		return nil, 0, err
	}

	offset := (page - 1) * pageSize
	err := query.Preload("Author").
		Order("created_at DESC").
		Offset(offset).
		Limit(pageSize).
		Find(&posts).Error

	return posts, total, err
}

// FindByTagID 根据标签查询
func (r *PostRepository) FindByTagID(tagID uint, page, pageSize int) ([]models.Post, int64, error) {
	var posts []models.Post
	var total int64

	query := r.db.Model(&models.Post{}).
		Joins("JOIN post_tags ON post_tags.post_id = posts.id").
		Where("post_tags.tag_id = ? AND posts.status = ?", tagID, 2)

	if err := query.Count(&total).Error; err != nil {
		return nil, 0, err
	}

	offset := (page - 1) * pageSize
	err := query.Preload("Author").
		Preload("Tags").
		Order("posts.created_at DESC").
		Offset(offset).
		Limit(pageSize).
		Find(&posts).Error

	return posts, total, err
}

// IncrementViewCount 增加浏览次数
func (r *PostRepository) IncrementViewCount(id uint) error {
	return r.db.Model(&models.Post{}).
		Where("id = ?", id).
		UpdateColumn("view_count", gorm.Expr("view_count + ?", 1)).
		Error
}

// FindHotPosts 查询热门文章（浏览量排序）
func (r *PostRepository) FindHotPosts(limit int) ([]models.Post, error) {
	var posts []models.Post
	err := r.db.Where("status = ?", 2).
		Order("view_count DESC").
		Limit(limit).
		Find(&posts).Error
	return posts, err
}

// FindByAuthorID 查询作者的文章
func (r *PostRepository) FindByAuthorID(authorID uint, page, pageSize int) ([]models.Post, int64, error) {
	var posts []models.Post
	var total int64

	query := r.db.Model(&models.Post{}).Where("author_id = ?", authorID)

	if err := query.Count(&total).Error; err != nil {
		return nil, 0, err
	}

	offset := (page - 1) * pageSize
	err := query.Order("created_at DESC").
		Offset(offset).
		Limit(pageSize).
		Find(&posts).Error

	return posts, total, err
}
```

## 事务处理

```go
// services/post_service.go
package services

import (
	"errors"

	"gorm.io/gorm"

	"your-project/models"
	"your-project/repositories"
)

type PostService struct {
	db       *gorm.DB
	postRepo *repositories.PostRepository
	tagRepo  *repositories.TagRepository
}

func NewPostService(db *gorm.DB, postRepo *repositories.PostRepository, tagRepo *repositories.TagRepository) *PostService {
	return &PostService{db: db, postRepo: postRepo, tagRepo: tagRepo}
}

type CreatePostInput struct {
	Title   string   `json:"title" binding:"required"`
	Content string   `json:"content" binding:"required"`
	Tags    []string `json:"tags"`
}

// Create 创建文章（带事务）
func (s *PostService) Create(authorID uint, input CreatePostInput) (*models.Post, error) {
	var post *models.Post

	err := s.db.Transaction(func(tx *gorm.DB) error {
		// 创建文章
		post = &models.Post{
			Title:    input.Title,
			Content:  input.Content,
			AuthorID: authorID,
			Status:   2, // published
		}

		if err := tx.Create(post).Error; err != nil {
			return err
		}

		// 处理标签
		if len(input.Tags) > 0 {
			var tags []models.Tag
			for _, tagName := range input.Tags {
				var tag models.Tag
				// 查找或创建标签
				result := tx.Where("name = ?", tagName).FirstOrCreate(&tag, models.Tag{Name: tagName})
				if result.Error != nil {
					return result.Error
				}
				tags = append(tags, tag)
			}

			// 关联标签
			if err := tx.Model(post).Association("Tags").Replace(tags); err != nil {
				return err
			}
		}

		return nil
	})

	if err != nil {
		return nil, err
	}

	return post, nil
}

// Publish 发布文章（状态变更）
func (s *PostService) Publish(id uint) error {
	return s.db.Transaction(func(tx *gorm.DB) error {
		var post models.Post
		if err := tx.First(&post, id).Error; err != nil {
			return err
		}

		if post.Status == 2 {
			return errors.New("post already published")
		}

		return tx.Model(&post).Update("status", 2).Error
	})
}
```

## 常见问题与优化

```go
// 1. N+1 查询问题
// ❌ 错误：循环中查询关联
posts, _ := postRepo.FindAll()
for _, post := range posts {
	author, _ := userRepo.FindByID(post.AuthorID) // N 次查询
}

// ✅ 正确：使用 Preload 预加载
db.Preload("Author").Find(&posts)

// 2. 分页优化
// ❌ 大偏移量性能差
db.Offset(100000).Limit(10).Find(&posts)

// ✅ 使用游标分页
db.Where("id > ?", lastID).Limit(10).Find(&posts)

// 3. 批量插入
// ❌ 循环插入
for _, user := range users {
	db.Create(&user) // N 次数据库操作
}

// ✅ 批量插入
db.CreateInBatches(users, 100) // 每批 100 条

// 4. 更新优化
// ❌ 全字段更新
db.Save(&user) // 更新所有字段

// ✅ 更新指定字段
db.Model(&user).Select("nickname", "avatar").Updates(user)

// 5. 查询优化
// ✅ 只查询需要的字段
db.Select("id", "title", "author_id").Find(&posts)

// ✅ 使用索引
db.Where("status = ? AND created_at > ?", 2, yesterday).Find(&posts)
```

这个方案提供了从数据库连接到业务逻辑的完整分层架构，适合中大型项目的长期维护。你的项目现在用的是什么 ORM？是否有特殊的性能需求或复杂查询场景？我可以针对具体场景给出更深入的优化建议。

更多文章