Gorm原始碼學習-建立行記錄

2022-12-17 18:00:18

1. 前言

Gorm原始碼學習系列

此文是Gorm原始碼學習系列的第二篇,主要梳理下通過Gorm建立表的流程。

 

2. 建立行記錄程式碼範例

gorm提供了以下幾個介面來建立行記錄

  • 一次建立一行 func (db *DB) Create(value interface{}) (tx *DB)
  • 批次建立 func (db *DB) CreateInBatches(value interface{}, batchSize int) (tx *DB)
  • 資料庫不存在主鍵時建立,存在時更新 func (db *DB) Save(value interface{}) (tx *DB)

詳細請看教學及原始碼finisher_api.go,這裡使用func (db *DB) Create(value interface{}) (tx *DB)來說明建立行記錄等大致流程。

 

2.1 宣告模型

type Stu struct {
	ID     int64 `gorm:"column:id; primary_key" json:"id"`
	Age    int64 `gorm:"column:age;"`
	Height int64 `gorm:"column:height;"`
	Weight int64 `gorm:"column:weight;"`
}

// 設定表名
func (Stu) TableName() string {
	return "t_student"
}

模型程式碼的主要用途如下,

  • 申明的表中有哪些列及每列的名稱、特性等,如gorm標籤指定每個字斷對於的表的列名
  • 通過實現Tabler介面指定了固定的表名,介面定義如下
type Tabler interface {
	TableName() string
}

關於模型定義中更多的約定和約束等,請看教學

出於分表等業務場景,我們並不希望固定模型等表名,gorm提供了func (db *DB) Table(name string, args ...interface{}) (tx *DB)等方法

來動態指定表名,詳情請看教學

 

2.2 建立行

func main() {
	// 資料庫連線, 具體檢視https://www.cnblogs.com/amos01/p/16890747.html 連線資料庫程式碼範例
	db, _ := dbOpen()
	// 開啟偵錯模式、會列印DML
	db = db.Debug()
	stu := &Stu{
		Age:    18,
		Height: 185,
		Weight: 70,
	}
	db = db.Create(stu)
	fmt.Printf("Error:%v ID:%v RowsAffected:%v\n", db.Error, stu.ID, db.RowsAffected)
}

程式碼輸出如下

$ go run main.go
2022/12/11 14:59:59 /Users/zbw/workspace/test/main.go:33
[1.910ms] [rows:1] INSERT INTO `t_student` (`age`,`height`,`weight`) VALUES (18,185,70)
Error:<nil> ID:1027 RowsAffected:1

從程式碼輸出可以看,行記錄的ID為1027,連線資料庫查詢,結果如下。

mysql> select * from t_student where id = 1027\G
*************************** 1. row ***************************
    id: 1027
   age: 18
height: 185
weight: 70
1 row in set (0.01 sec)

因此,我們帶著以下問題來梳理下Gorm建立行記錄的流程

  • 如何從model到DML語句的
  • 如何將ID寫入到model的 

 

3. 從Model到DML

func (db *DB) Create(value interface{}) (tx *DB)的實現如下

// Create inserts value, returning the inserted data's primary key in value's id
func (db *DB) Create(value interface{}) (tx *DB) {
	if db.CreateBatchSize > 0 {
		return db.CreateInBatches(value, db.CreateBatchSize)
	}

	tx = db.getInstance()
	tx.Statement.Dest = value
	return tx.callbacks.Create().Execute(tx)
}

func (p *processor) Execute(db *DB) *DB 的實現比較長,具體程式碼見github

總結下來,做了兩件主要的事情,

  • 解析model獲取表名、每列的定義等
  • 執行勾點函數以及建立行函數

X

3.1 資料結構理解

  • gorm.Statement
檢視gorm.Statement程式碼
// Statement statement
type Statement struct {
	*DB
	TableExpr            *clause.Expr
	Table                string      // 表名
	Model                interface{} // model定義
	Unscoped             bool
	Dest                 interface{} // model的另外一種表達,如map
	ReflectValue         reflect.Value
	Clauses              map[string]clause.Clause
	BuildClauses         []string
	Distinct             bool
	Selects              []string // selected columns
	Omits                []string // omit columns
	Joins                []join
	Preloads             map[string][]interface{}
	Settings             sync.Map
	ConnPool             ConnPool       // 資料庫連線
	Schema               *schema.Schema // 表結構化資訊
	Context              context.Context
	RaiseErrorOnNotFound bool
	SkipHooks            bool
	SQL                  strings.Builder // 最終的DML語句
	Vars                 []interface{}   // DML語句的引數值
	CurDestIndex         int             // 批次建立/更新時,gorm當前操作的陣列/slice的下標
	attrs                []interface{}
	assigns              []interface{}
	scopes               []func(*DB) *DB
}
  • schema.Schem
檢視schema.Schema程式碼
type Schema struct {
	Name                      string
	ModelType                 reflect.Type
	Table                     string // 表名
	PrioritizedPrimaryField   *Field
	DBNames                   []string // 表每列的名字
	PrimaryFields             []*Field
	PrimaryFieldDBNames       []string // 表的主鍵列明
	Fields                    []*Field // gorm自定義的model每個字短
	FieldsByName              map[string]*Field
	FieldsByDBName            map[string]*Field
	FieldsWithDefaultDBValue  []*Field // fields with default value assigned by database
	Relationships             Relationships
	CreateClauses             []clause.Interface // 建立行的子句
	QueryClauses              []clause.Interface
	UpdateClauses             []clause.Interface
	DeleteClauses             []clause.Interface
	BeforeCreate, AfterCreate bool
	BeforeUpdate, AfterUpdate bool
	BeforeDelete, AfterDelete bool
	BeforeSave, AfterSave     bool
	AfterFind                 bool
	err                       error
	initialized               chan struct{}
	namer                     Namer
	cacheStore                *sync.Map
}
  • schema.Field
檢視schema.Field程式碼
// Field is the representation of model schema's field
type Field struct {
	Name                   string // model的欄位名
	DBName                 string // 對應表的列名
	BindNames              []string
	DataType               DataType
	GORMDataType           DataType
	PrimaryKey             bool
	AutoIncrement          bool
	AutoIncrementIncrement int64
	Creatable              bool
	Updatable              bool
	Readable               bool
	AutoCreateTime         TimeType
	AutoUpdateTime         TimeType
	HasDefaultValue        bool
	DefaultValue           string
	DefaultValueInterface  interface{}
	NotNull                bool
	Unique                 bool
	Comment                string
	Size                   int
	Precision              int
	Scale                  int
	IgnoreMigration        bool
	FieldType              reflect.Type // 反射型別
	IndirectFieldType      reflect.Type // 反射型別
	StructField            reflect.StructField // model欄位資訊
	Tag                    reflect.StructTag // tag
	TagSettings            map[string]string
	Schema                 *Schema
	EmbeddedSchema         *Schema
	OwnerSchema            *Schema
	ReflectValueOf         func(context.Context, reflect.Value) reflect.Value                  // 通過反射獲取該欄位的反射物件
	ValueOf                func(context.Context, reflect.Value) (value interface{}, zero bool) // 通過反射獲取該欄位的值 get方法
	Set                    func(context.Context, reflect.Value, interface{}) error             // 通過反射設定該欄位的值 set方法
	Serializer             SerializerInterface
	NewValuePool           FieldNewValuePool
}
  • clause.Interfaceclause.Clause

gorm定義了多種clause,包括

檢視clause.Interface程式碼
// Interface clause interface
type Interface interface {
	Name() string
	Build(Builder)
	MergeClause(*Clause)
}
檢視clause.Clause程式碼
// Clause
type Clause struct {
	Name                string // WHERE
	BeforeExpression    Expression
	AfterNameExpression Expression
	AfterExpression     Expression
	Expression          Expression
	Builder             ClauseBuilder
}

 

3.2 解析Model

通過呼叫stmt.Parse(stmt.Model)進行model解析

stmt.Parse(stmt.Model)會呼叫到函數func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Namer, specialTableName string) (*Schema, error) 進行解析。

詳細程式碼見schema.go,下面列舉重要的幾個點。

  • 通過反射判斷dest interface{}是否為reflect.Struct
  • 通過介面獲取表名,其中stu實現了Tabler介面
	// 獲取表名
	modelValue := reflect.New(modelType)
	tableName := namer.TableName(modelType.Name())
	if tabler, ok := modelValue.Interface().(Tabler); ok {
		tableName = tabler.TableName()
	}
	if tabler, ok := modelValue.Interface().(TablerWithNamer); ok {
		tableName = tabler.TableName(namer)
	}
	if en, ok := namer.(embeddedNamer); ok {
		tableName = en.Table
	}
	if specialTableName != "" && specialTableName != tableName {
		tableName = specialTableName
	}
  • 解析model每個欄位
// 通過反射獲取每個欄位
for i := 0; i < modelType.NumField(); i++ {
    if fieldStruct := modelType.Field(i); ast.IsExported(fieldStruct.Name) {
        // 解析每個欄位
        if field := schema.ParseField(fieldStruct); field.EmbeddedSchema != nil {
            schema.Fields = append(schema.Fields, field.EmbeddedSchema.Fields...)
        } else {
            schema.Fields = append(schema.Fields, field)
        }
    }
}
  • 放到map方便查詢,並且通過func (field *Field) setupValuerAndSetter()初始化每個Field的ReflectValueOfValueOfSet方法。
    for _, field := range schema.Fields {
        if field.DBName == "" && field.DataType != "" {
            field.DBName = namer.ColumnName(schema.Table, field.Name)
        }
        if field.DBName != "" {
            // nonexistence or shortest path or first appear prioritized if has permission
            if v, ok := schema.FieldsByDBName[field.DBName]; !ok || ((field.Creatable || field.Updatable || field.Readable) && len(field.BindNames) < len(v.BindNames)) {
                if _, ok := schema.FieldsByDBName[field.DBName]; !ok {
                    schema.DBNames = append(schema.DBNames, field.DBName)
                }
                // gorm tag欄位到field的對映
                schema.FieldsByDBName[field.DBName] = field
                // model 欄位到field的對映
                schema.FieldsByName[field.Name] = field
                if v != nil && v.PrimaryKey {
                    for idx, f := range schema.PrimaryFields {
                        if f == v {
                            schema.PrimaryFields = append(schema.PrimaryFields[0:idx], schema.PrimaryFields[idx+1:]...)
                        }
                    }
                }
                // 主鍵
                if field.PrimaryKey {
                    schema.PrimaryFields = append(schema.PrimaryFields, field)
                }
            }
        }
        if of, ok := schema.FieldsByName[field.Name]; !ok || of.TagSettings["-"] == "-" {
            schema.FieldsByName[field.Name] = field
        }
        // 掛載欄位的set方法和get方法
        field.setupValuerAndSetter()
    }

 

值得一提的是,每個model解析後的結果是一致,可以將結果解析的結構快取下來,並且通過chan來解決並行的問題。

解析model之後,通過process獲取到勾點函數及建立行的函數,具體程式碼見Github

	for _, f := range p.fns {
		f(db)
	}

 

3.3 執行勾點函數及建立行的函數

建立行的函數及對應的勾點函數位於create.go

  • 建立行記錄
if db.Statement.SQL.Len() == 0 {
    db.Statement.SQL.Grow(180)
    db.Statement.AddClauseIfNotExists(clause.Insert{})
    db.Statement.AddClause(ConvertToCreateValues(db.Statement))

    db.Statement.Build(db.Statement.BuildClauses...)
}

 這裡插入兩個clause.Clause,分別為clause.Insert以及clause.Values,然後呼叫這兩種clause.Clausebuild方法生成SQL語句。

首先,看下ConvertToCreateValues的實現,這裡只擷取部分程式碼

values = clause.Values{Columns: make([]clause.Column, 0, len(stmt.Schema.DBNames))}
// 獲取每一列的名字
for _, db := range stmt.Schema.DBNames {
    if field := stmt.Schema.FieldsByDBName[db]; !field.HasDefaultValue || field.DefaultValueInterface != nil {
	    if v, ok := selectColumns[db]; (ok && v) || (!ok && (!restricted || field.AutoCreateTime > 0 || field.AutoUpdateTime > 0)) {
		    values.Columns = append(values.Columns, clause.Column{Name: db})
		}
	}
}

// 獲取每一列對應的值
switch stmt.ReflectValue.Kind() {
    case reflect.Slice, reflect.Array:
    case reflect.Struct:
        values.Values = [][]interface{}{make([]interface{}, len(values.Columns))}
        for idx, column := range values.Columns {
            field := stmt.Schema.FieldsByDBName[column.Name]
            // func (field *Field) setupValuerAndSetter() 掛載的方法
            if values.Values[0][idx], isZero = field.ValueOf(stmt.Context, stmt.ReflectValue); isZero {
                if field.DefaultValueInterface != nil {
                    values.Values[0][idx] = field.DefaultValueInterface
                    stmt.AddError(field.Set(stmt.Context, stmt.ReflectValue, field.DefaultValueInterface))
                } else if field.AutoCreateTime > 0 || field.AutoUpdateTime > 0 {
                    stmt.AddError(field.Set(stmt.Context, stmt.ReflectValue, curTime))
                    values.Values[0][idx], _ = field.ValueOf(stmt.Context, stmt.ReflectValue)
                }
            } else if field.AutoUpdateTime > 0 && updateTrackTime {
                stmt.AddError(field.Set(stmt.Context, stmt.ReflectValue, curTime))
                values.Values[0][idx], _ = field.ValueOf(stmt.Context, stmt.ReflectValue)
            }
        }

通過ConvertToCreateValues獲取了每一列的名稱及對應的值。

接下來,看clause.ClauseSQL語句的過程。

遍歷加入clause,此時分別為clause.Insert以及clause.Values

// Build build sql with clauses names
func (stmt *Statement) Build(clauses ...string) {
    var firstClauseWritten bool
    for _, name := range clauses {
        if c, ok := stmt.Clauses[name]; ok {
            // 程式碼有刪減
            c.Build(stmt)
        }
    }
}

接著呼叫func (c Clause) Build(builder Builder)

// Build build clause 
func (c Clause) Build(builder Builder) {
    // 有刪減
    // c為clause.Insert以及clause.Values
    if c.Name != "" {
        // builder寫入 INSERT 或者 VALUES
        builder.WriteString(c.Name)
        builder.WriteByte(' ')
    }
    // 通過clause.Insert以及clause.Values的MergeClause函數,c.Expression為clause.Insert以及clause.Values
    // 因此,這裡呼叫clause.Insert或者clause.Values的Build的方法
    c.Expression.Build(builder)
}

接下來分別看clause.Insert以及clause.Values

// Build build insert clause
func (insert Insert) Build(builder Builder) {
    // builder寫入INTO,此時builder為INSERT INTO
    builder.WriteString("INTO ")
    // builder寫入表名
    builder.WriteQuoted(currentTable)
}

從呼叫的鏈路可以得出,這裡builderstmt *Statement,並且currentTable型別為clause.Table,因此

// WriteQuoted write quoted value
func (stmt *Statement) WriteQuoted(value interface{}) {
    stmt.QuoteTo(&stmt.SQL, value)
}

// QuoteTo write quoted value to writer 有刪減
func (stmt *Statement) QuoteTo(writer clause.Writer, field interface{}) {
    write := func(raw bool, str string) {
        // mysql驅動Dialector
        stmt.DB.Dialector.QuoteTo(writer, str)
    }
    switch v := field.(type) {
    case clause.Table:
        write(v.Raw, stmt.Table)
    }
}

至此,builder已經拼裝出INSERT INTO `t_student` ,解析來再看clause.Valuesbuild方法

// Build build from clause
func (values Values) Build(builder Builder) {
    if len(values.Columns) > 0 {
        builder.WriteByte('(')
        for idx, column := range values.Columns {
            if idx > 0 {
                builder.WriteByte(',')
            }
            builder.WriteQuoted(column)
        }
        builder.WriteByte(')')
        builder.WriteString(" VALUES ")
        for idx, value := range values.Values {
            if idx > 0 {
                builder.WriteByte(',')
            }
            builder.WriteByte('(')
            builder.AddVar(builder, value...)
            builder.WriteByte(')')
        }
    } else {
        builder.WriteString("DEFAULT VALUES")
    }
}

func (values Values) Build(builder Builder)取出所有列名和列對應的值

最終builder拼裝成例子的完整SQL語句INSERT INTO `t_student` (`age`,`height`,`weight`) VALUES (18,185,70)

有了SQL語句,就可以執行了

result, err := db.Statement.ConnPool.ExecContext(
    db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...,
)

通過前一面學習,db.Statement.ConnPool的值為sql.DB,實際執行的函數為func (db *DB) ExecContext(ctx context.Context, query string, args ...any) (Result, error)

至此,從Model到DML到流程已經完成。

 

4. 將ID寫入到model的 

看返回引數sql.Result,因此通過LastInsertId() (int64, error)可以獲取到插入行的ID值。

// A Result summarizes an executed SQL command.
type Result interface {
	// LastInsertId returns the integer generated by the database
	// in response to a command. Typically this will be from an
	// "auto increment" column when inserting a new row. Not all
	// databases support this feature, and the syntax of such
	// statements varies.
	LastInsertId() (int64, error)

	// RowsAffected returns the number of rows affected by an
	// update, insert, or delete. Not every database or database
	// driver may support this.
	RowsAffected() (int64, error)
}

獲取到剛插入的行ID值,再通過反射寫入model的ID欄位即可。

db.RowsAffected, _ = result.RowsAffected()
if db.RowsAffected != 0 && db.Statement.Schema != nil &&
    db.Statement.Schema.PrioritizedPrimaryField != nil &&
    db.Statement.Schema.PrioritizedPrimaryField.HasDefaultValue {
    insertID, err := result.LastInsertId()
    switch db.Statement.ReflectValue.Kind() {
    case reflect.Struct:
        _, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.Context, db.Statement.ReflectValue)
        if isZero {
            // 通過反射更新ID
            db.AddError(db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.Context, db.Statement.ReflectValue, insertID))
        }
    }
}

 

5. 總結

使用反射解析Model,獲得每個成員對應的表的列名、值等資訊。

定義SQL各個關鍵詞如INSERTVALUESFROMDELETE的結構體,並實現clause.Interface介面

進而對SQL語句的構造進行抽象封裝。