Gorm原始碼學習系列
此文是Gorm原始碼學習系列的第二篇,主要梳理下通過Gorm建立表的流程。
gorm提供了以下幾個介面來建立行記錄
func (db *DB) Create(value interface{}) (tx *DB)
func (db *DB) CreateInBatches(value interface{}, batchSize int) (tx *DB)
func (db *DB) Save(value interface{}) (tx *DB)
詳細請看教學及原始碼finisher_api.go,這裡使用func (db *DB) Create(value interface{}) (tx *DB)
來說明建立行記錄等大致流程。
type Stu struct {
ID int64 `gorm:"column:id; primary_key" json:"id"`
Age int64 `gorm:"column:age;"`
Height int64 `gorm:"column:height;"`
Weight int64 `gorm:"column:weight;"`
}
// 設定表名
func (Stu) TableName() string {
return "t_student"
}
模型程式碼的主要用途如下,
gorm
標籤指定每個字斷對於的表的列名Tabler
介面指定了固定的表名,介面定義如下type Tabler interface {
TableName() string
}
關於模型定義中更多的約定和約束等,請看教學。
出於分表等業務場景,我們並不希望固定模型等表名,gorm提供了func (db *DB) Table(name string, args ...interface{}) (tx *DB)
等方法
來動態指定表名,詳情請看教學。
func main() {
// 資料庫連線, 具體檢視https://www.cnblogs.com/amos01/p/16890747.html 連線資料庫程式碼範例
db, _ := dbOpen()
// 開啟偵錯模式、會列印DML
db = db.Debug()
stu := &Stu{
Age: 18,
Height: 185,
Weight: 70,
}
db = db.Create(stu)
fmt.Printf("Error:%v ID:%v RowsAffected:%v\n", db.Error, stu.ID, db.RowsAffected)
}
程式碼輸出如下
$ go run main.go
2022/12/11 14:59:59 /Users/zbw/workspace/test/main.go:33
[1.910ms] [rows:1] INSERT INTO `t_student` (`age`,`height`,`weight`) VALUES (18,185,70)
Error:<nil> ID:1027 RowsAffected:1
從程式碼輸出可以看,行記錄的ID為1027,連線資料庫查詢,結果如下。
mysql> select * from t_student where id = 1027\G
*************************** 1. row ***************************
id: 1027
age: 18
height: 185
weight: 70
1 row in set (0.01 sec)
因此,我們帶著以下問題來梳理下Gorm建立行記錄的流程
func (db *DB) Create(value interface{}) (tx *DB)
的實現如下
// Create inserts value, returning the inserted data's primary key in value's id
func (db *DB) Create(value interface{}) (tx *DB) {
if db.CreateBatchSize > 0 {
return db.CreateInBatches(value, db.CreateBatchSize)
}
tx = db.getInstance()
tx.Statement.Dest = value
return tx.callbacks.Create().Execute(tx)
}
func (p *processor) Execute(db *DB) *DB
的實現比較長,具體程式碼見github
總結下來,做了兩件主要的事情,
X
gorm.Statement
// Statement statement
type Statement struct {
*DB
TableExpr *clause.Expr
Table string // 表名
Model interface{} // model定義
Unscoped bool
Dest interface{} // model的另外一種表達,如map
ReflectValue reflect.Value
Clauses map[string]clause.Clause
BuildClauses []string
Distinct bool
Selects []string // selected columns
Omits []string // omit columns
Joins []join
Preloads map[string][]interface{}
Settings sync.Map
ConnPool ConnPool // 資料庫連線
Schema *schema.Schema // 表結構化資訊
Context context.Context
RaiseErrorOnNotFound bool
SkipHooks bool
SQL strings.Builder // 最終的DML語句
Vars []interface{} // DML語句的引數值
CurDestIndex int // 批次建立/更新時,gorm當前操作的陣列/slice的下標
attrs []interface{}
assigns []interface{}
scopes []func(*DB) *DB
}
schema.Schem
type Schema struct {
Name string
ModelType reflect.Type
Table string // 表名
PrioritizedPrimaryField *Field
DBNames []string // 表每列的名字
PrimaryFields []*Field
PrimaryFieldDBNames []string // 表的主鍵列明
Fields []*Field // gorm自定義的model每個字短
FieldsByName map[string]*Field
FieldsByDBName map[string]*Field
FieldsWithDefaultDBValue []*Field // fields with default value assigned by database
Relationships Relationships
CreateClauses []clause.Interface // 建立行的子句
QueryClauses []clause.Interface
UpdateClauses []clause.Interface
DeleteClauses []clause.Interface
BeforeCreate, AfterCreate bool
BeforeUpdate, AfterUpdate bool
BeforeDelete, AfterDelete bool
BeforeSave, AfterSave bool
AfterFind bool
err error
initialized chan struct{}
namer Namer
cacheStore *sync.Map
}
schema.Field
// Field is the representation of model schema's field
type Field struct {
Name string // model的欄位名
DBName string // 對應表的列名
BindNames []string
DataType DataType
GORMDataType DataType
PrimaryKey bool
AutoIncrement bool
AutoIncrementIncrement int64
Creatable bool
Updatable bool
Readable bool
AutoCreateTime TimeType
AutoUpdateTime TimeType
HasDefaultValue bool
DefaultValue string
DefaultValueInterface interface{}
NotNull bool
Unique bool
Comment string
Size int
Precision int
Scale int
IgnoreMigration bool
FieldType reflect.Type // 反射型別
IndirectFieldType reflect.Type // 反射型別
StructField reflect.StructField // model欄位資訊
Tag reflect.StructTag // tag
TagSettings map[string]string
Schema *Schema
EmbeddedSchema *Schema
OwnerSchema *Schema
ReflectValueOf func(context.Context, reflect.Value) reflect.Value // 通過反射獲取該欄位的反射物件
ValueOf func(context.Context, reflect.Value) (value interface{}, zero bool) // 通過反射獲取該欄位的值 get方法
Set func(context.Context, reflect.Value, interface{}) error // 通過反射設定該欄位的值 set方法
Serializer SerializerInterface
NewValuePool FieldNewValuePool
}
clause.Interface
及clause.Clause
gorm定義了多種clause,包括
// Interface clause interface
type Interface interface {
Name() string
Build(Builder)
MergeClause(*Clause)
}
// Clause
type Clause struct {
Name string // WHERE
BeforeExpression Expression
AfterNameExpression Expression
AfterExpression Expression
Expression Expression
Builder ClauseBuilder
}
通過呼叫stmt.Parse(stmt.Model)
進行model解析
stmt.Parse(stmt.Model)
會呼叫到函數func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Namer, specialTableName string) (*Schema, error)
進行解析。
詳細程式碼見schema.go,下面列舉重要的幾個點。
dest interface{}
是否為reflect.Struct
Tabler
介面 // 獲取表名
modelValue := reflect.New(modelType)
tableName := namer.TableName(modelType.Name())
if tabler, ok := modelValue.Interface().(Tabler); ok {
tableName = tabler.TableName()
}
if tabler, ok := modelValue.Interface().(TablerWithNamer); ok {
tableName = tabler.TableName(namer)
}
if en, ok := namer.(embeddedNamer); ok {
tableName = en.Table
}
if specialTableName != "" && specialTableName != tableName {
tableName = specialTableName
}
// 通過反射獲取每個欄位
for i := 0; i < modelType.NumField(); i++ {
if fieldStruct := modelType.Field(i); ast.IsExported(fieldStruct.Name) {
// 解析每個欄位
if field := schema.ParseField(fieldStruct); field.EmbeddedSchema != nil {
schema.Fields = append(schema.Fields, field.EmbeddedSchema.Fields...)
} else {
schema.Fields = append(schema.Fields, field)
}
}
}
func (field *Field) setupValuerAndSetter()
初始化每個Field的ReflectValueOf
、ValueOf
、Set
方法。
for _, field := range schema.Fields {
if field.DBName == "" && field.DataType != "" {
field.DBName = namer.ColumnName(schema.Table, field.Name)
}
if field.DBName != "" {
// nonexistence or shortest path or first appear prioritized if has permission
if v, ok := schema.FieldsByDBName[field.DBName]; !ok || ((field.Creatable || field.Updatable || field.Readable) && len(field.BindNames) < len(v.BindNames)) {
if _, ok := schema.FieldsByDBName[field.DBName]; !ok {
schema.DBNames = append(schema.DBNames, field.DBName)
}
// gorm tag欄位到field的對映
schema.FieldsByDBName[field.DBName] = field
// model 欄位到field的對映
schema.FieldsByName[field.Name] = field
if v != nil && v.PrimaryKey {
for idx, f := range schema.PrimaryFields {
if f == v {
schema.PrimaryFields = append(schema.PrimaryFields[0:idx], schema.PrimaryFields[idx+1:]...)
}
}
}
// 主鍵
if field.PrimaryKey {
schema.PrimaryFields = append(schema.PrimaryFields, field)
}
}
}
if of, ok := schema.FieldsByName[field.Name]; !ok || of.TagSettings["-"] == "-" {
schema.FieldsByName[field.Name] = field
}
// 掛載欄位的set方法和get方法
field.setupValuerAndSetter()
}
值得一提的是,每個model解析後的結果是一致,可以將結果解析的結構快取下來,並且通過chan
來解決並行的問題。
解析model之後,通過process
獲取到勾點函數及建立行的函數,具體程式碼見Github
for _, f := range p.fns {
f(db)
}
建立行的函數及對應的勾點函數位於create.go
if db.Statement.SQL.Len() == 0 {
db.Statement.SQL.Grow(180)
db.Statement.AddClauseIfNotExists(clause.Insert{})
db.Statement.AddClause(ConvertToCreateValues(db.Statement))
db.Statement.Build(db.Statement.BuildClauses...)
}
這裡插入兩個clause.Clause
,分別為clause.Insert
以及clause.Values
,然後呼叫這兩種clause.Clause
的build
方法生成SQL
語句。
首先,看下ConvertToCreateValues
的實現,這裡只擷取部分程式碼
values = clause.Values{Columns: make([]clause.Column, 0, len(stmt.Schema.DBNames))}
// 獲取每一列的名字
for _, db := range stmt.Schema.DBNames {
if field := stmt.Schema.FieldsByDBName[db]; !field.HasDefaultValue || field.DefaultValueInterface != nil {
if v, ok := selectColumns[db]; (ok && v) || (!ok && (!restricted || field.AutoCreateTime > 0 || field.AutoUpdateTime > 0)) {
values.Columns = append(values.Columns, clause.Column{Name: db})
}
}
}
// 獲取每一列對應的值
switch stmt.ReflectValue.Kind() {
case reflect.Slice, reflect.Array:
case reflect.Struct:
values.Values = [][]interface{}{make([]interface{}, len(values.Columns))}
for idx, column := range values.Columns {
field := stmt.Schema.FieldsByDBName[column.Name]
// func (field *Field) setupValuerAndSetter() 掛載的方法
if values.Values[0][idx], isZero = field.ValueOf(stmt.Context, stmt.ReflectValue); isZero {
if field.DefaultValueInterface != nil {
values.Values[0][idx] = field.DefaultValueInterface
stmt.AddError(field.Set(stmt.Context, stmt.ReflectValue, field.DefaultValueInterface))
} else if field.AutoCreateTime > 0 || field.AutoUpdateTime > 0 {
stmt.AddError(field.Set(stmt.Context, stmt.ReflectValue, curTime))
values.Values[0][idx], _ = field.ValueOf(stmt.Context, stmt.ReflectValue)
}
} else if field.AutoUpdateTime > 0 && updateTrackTime {
stmt.AddError(field.Set(stmt.Context, stmt.ReflectValue, curTime))
values.Values[0][idx], _ = field.ValueOf(stmt.Context, stmt.ReflectValue)
}
}
通過ConvertToCreateValues
獲取了每一列的名稱及對應的值。
接下來,看clause.Clause
到SQL
語句的過程。
遍歷加入clause
,此時分別為clause.Insert
以及clause.Values
// Build build sql with clauses names
func (stmt *Statement) Build(clauses ...string) {
var firstClauseWritten bool
for _, name := range clauses {
if c, ok := stmt.Clauses[name]; ok {
// 程式碼有刪減
c.Build(stmt)
}
}
}
接著呼叫func (c Clause) Build(builder Builder)
// Build build clause
func (c Clause) Build(builder Builder) {
// 有刪減
// c為clause.Insert以及clause.Values
if c.Name != "" {
// builder寫入 INSERT 或者 VALUES
builder.WriteString(c.Name)
builder.WriteByte(' ')
}
// 通過clause.Insert以及clause.Values的MergeClause函數,c.Expression為clause.Insert以及clause.Values
// 因此,這裡呼叫clause.Insert或者clause.Values的Build的方法
c.Expression.Build(builder)
}
接下來分別看clause.Insert
以及clause.Values
// Build build insert clause
func (insert Insert) Build(builder Builder) {
// builder寫入INTO,此時builder為INSERT INTO
builder.WriteString("INTO ")
// builder寫入表名
builder.WriteQuoted(currentTable)
}
從呼叫的鏈路可以得出,這裡builder
為stmt *Statement
,並且currentTable
型別為clause.Table
,因此
// WriteQuoted write quoted value
func (stmt *Statement) WriteQuoted(value interface{}) {
stmt.QuoteTo(&stmt.SQL, value)
}
// QuoteTo write quoted value to writer 有刪減
func (stmt *Statement) QuoteTo(writer clause.Writer, field interface{}) {
write := func(raw bool, str string) {
// mysql驅動Dialector
stmt.DB.Dialector.QuoteTo(writer, str)
}
switch v := field.(type) {
case clause.Table:
write(v.Raw, stmt.Table)
}
}
至此,builder
已經拼裝出INSERT INTO `t_student`
,解析來再看clause.Values
的build
方法
// Build build from clause
func (values Values) Build(builder Builder) {
if len(values.Columns) > 0 {
builder.WriteByte('(')
for idx, column := range values.Columns {
if idx > 0 {
builder.WriteByte(',')
}
builder.WriteQuoted(column)
}
builder.WriteByte(')')
builder.WriteString(" VALUES ")
for idx, value := range values.Values {
if idx > 0 {
builder.WriteByte(',')
}
builder.WriteByte('(')
builder.AddVar(builder, value...)
builder.WriteByte(')')
}
} else {
builder.WriteString("DEFAULT VALUES")
}
}
func (values Values) Build(builder Builder)
取出所有列名和列對應的值
最終builder
拼裝成例子的完整SQL語句INSERT INTO `t_student` (`age`,`height`,`weight`) VALUES (18,185,70)
有了SQL語句,就可以執行了
result, err := db.Statement.ConnPool.ExecContext(
db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...,
)
通過前一面學習,db.Statement.ConnPool
的值為sql.DB
,實際執行的函數為func (db *DB) ExecContext(ctx context.Context, query string, args ...any) (Result, error)
至此,從Model到DML到流程已經完成。
看返回引數sql.Result
,因此通過LastInsertId() (int64, error)
可以獲取到插入行的ID值。
// A Result summarizes an executed SQL command.
type Result interface {
// LastInsertId returns the integer generated by the database
// in response to a command. Typically this will be from an
// "auto increment" column when inserting a new row. Not all
// databases support this feature, and the syntax of such
// statements varies.
LastInsertId() (int64, error)
// RowsAffected returns the number of rows affected by an
// update, insert, or delete. Not every database or database
// driver may support this.
RowsAffected() (int64, error)
}
獲取到剛插入的行ID值,再通過反射寫入model的ID欄位即可。
db.RowsAffected, _ = result.RowsAffected()
if db.RowsAffected != 0 && db.Statement.Schema != nil &&
db.Statement.Schema.PrioritizedPrimaryField != nil &&
db.Statement.Schema.PrioritizedPrimaryField.HasDefaultValue {
insertID, err := result.LastInsertId()
switch db.Statement.ReflectValue.Kind() {
case reflect.Struct:
_, isZero := db.Statement.Schema.PrioritizedPrimaryField.ValueOf(db.Statement.Context, db.Statement.ReflectValue)
if isZero {
// 通過反射更新ID
db.AddError(db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.Context, db.Statement.ReflectValue, insertID))
}
}
}
使用反射解析Model,獲得每個成員對應的表的列名、值等資訊。
定義SQL各個關鍵詞如INSERT
、VALUES
、FROM
、DELETE
的結構體,並實現clause.Interface
介面
進而對SQL語句的構造進行抽象封裝。