通过为数据库创建客户端的示例在Go中生成代码

在本文中,我将考虑Golang中代码生成的问题。我注意到在Go文章的评论中经常提到代码生成和反射,这引起了激烈的辩论。同时,尽管在Go上的项目中使用了很多代码,但是在Hub上很少有关于代码生成的文章。在本文中,我将尝试告诉您什么是代码生成,并通过代码示例描述应用程序的范围。另外,我不会忽略反思。

使用代码生成时


在Habré上,这里这里已经有关于该主题的好文章,我不再赘述。

在以下情况下应使用代码生成:

  • 提高代码速度,即替换反射;
  • 减少程序员的例程(以及与之相关的错误);
  • 根据给定规则执行包装器。

从示例中,我们可以考虑标准语言提供程序中包含的Stringer库,该库允许您自动为数字常量集生成String()方法。使用它,可以实现变量名的输出。在以上文章中详细描述了该库的示例。最有趣的示例是从调色板中导出颜色名称。在那里使用代码生成可以避免在更改面板时在多个位置更改代码。

从一个更实际的示例中,我们可以提到Mail.ru中的easyjson库。该库使您可以加快从/到结构的masrshall / unmarshall JSON的执行。他们在基准上的实施绕过了所有替代方案。要使用该库,您需要调用easyjson,它将为它在已传输文件中找到的所有结构生成代码,或者仅为那些表示注释// easyjson:json的结构生成代码。以用户结构为例:

type User struct{
    ID int
    Login string
    Email string
    Level int
}

对于包含它的文件,运行代码生成:

easyjson -all main.go

结果,我们获得了User的方法:

  • MarshalEasyJSON(w * jwriter.Writer)-将结构转换为JSON字节数组;
  • UnmarshalEasyJSON(l * jlexer.Lexer)-从字节数组转换为结构。

为了与标准json接口兼容,必须使用函数MarshalJSON()([]字节,错误)和UnmarshalJSON(数据[]字节)错误。

Easyjson代码

func TestEasyJSON() {
	testJSON := `{"ID":123, "Login":"TestUser", "Email":"user@gmail.com", "Level":12}`
	JSONb := []byte(testJSON)
	fmt.Println(testJSON)
	recvUser := &User{}
	recvUser.UnmarshalJSON(JSONb)
	fmt.Println(recvUser)
	recvUser.Level += 1
	outJSON, _ := recvUser.MarshalJSON()
	fmt.Println(string(outJSON))
}


在此功能中,我们首先将JSON转换为结构,添加一个级别并打印结果JSON。easyjson生成代码意味着摆脱运行时反射并提高代码性能。

代码生成活跃地用于创建通过gRPC进行通信的微服务。它使用protobuf格式来描述服务方法-使用中间语言EDL。在描述服务之后,将启动协议编译器,该协议编译器会生成所需编程语言的代码。在生成的代码中,我们获得需要在服务器中实现的接口以及在客户端上用于组织通信的方法。结果很方便,我们可以用一种格式描述我们的服务,并为将描述每个交互元素的编程语言生成代码。

同样,代码生成可用于框架的开发。例如,要实现不需要由应用程序开发人员编写的代码,但是对于正确的操作而言这是必需的。例如,为了创建表单字段验证器,中间件自动生成,动态生成DBMS的客户端。

Go代码生成器实现


让我们在实践中检查Go中代码生成机制的工作原理。首先,有必要提及AST-抽象语法树或抽象语法树。有关详细信息,您可以转到Wikipedia。出于我们的目的,有必要了解整个程序是以图形的形式构建的,其中顶点是用编程语言的运算符映射(标记)的,叶是对应的操作数的映射。

因此,对于初学者而言,我们将需要以下软件包:

go / ast
go / parser /
go / token /

使用代码解析文件并编译树是通过以下命令执行的


fset := token.NewFileSet()
node, err := parser.ParseFile(fset, os.Args[1], nil, parser.ParseComments)

我们指示文件名应从命令行的第一个参数获取,我们还要求将注释添加到树中。

通常,为了控制代码的生成,用户(根据代码生成其他代码的代码开发人员)可以使用注释或标签(在结构字段附近写“ json:””时)。

例如,我们将编写一个用于生成数据库的代码生成器。代码生成器将查看传输到该文件的文件,查找具有相应注释的结构,并在该结构(CRUD方法)上创建包装器以进行数据库交互。我们将使用以下参数:

  • dbe注释:{“ table”:“ users”},您可以在其中定义结构记录所在的表;
  • dbe标签用于结构的字段,您可以在其中指定要在其中放置数据库的字段值和属性的列的名称:primary_key和not_null。创建表时将使用它们。对于字段名,可以使用“-”,以免为其创建列。

我将提前预约该项目尚未投入使用,它不会包含必要的检查和保护措施的一部分。如果有兴趣,我会继续发展。

因此,我们确定了控制代码生成的任务和参数后,就可以开始编写代码了。

所有代码的链接将在本文的结尾。

我们开始绕过结果树,然后解析第一级的每个元素。Go具有预定义的解析类型:BadDecl,GenDecl和FuncDecl。

类型说明
// A BadDecl node is a placeholder for declarations containing
// syntax errors for which no correct declaration nodes can be
// created.
//
BadDecl struct {
    From, To token.Pos // position range of bad declaration
}
// A GenDecl node (generic declaration node) represents an import,
// constant, type or variable declaration. A valid Lparen position
// (Lparen.IsValid()) indicates a parenthesized declaration.
//
// Relationship between Tok value and Specs element type:
//
// token.IMPORT *ImportSpec
// token.CONST *ValueSpec
// token.TYPE *TypeSpec
// token.VAR *ValueSpec
//
GenDecl struct {
    Doc *CommentGroup // associated documentation; or nil
    TokPos token.Pos // position of Tok
    Tok token.Token // IMPORT, CONST, TYPE, VAR
    Lparen token.Pos // position of '(', if any
    Specs []Spec
    Rparen token.Pos // position of ')', if any
}
// A FuncDecl node represents a function declaration.
FuncDecl struct {
    Doc *CommentGroup // associated documentation; or nil
    Recv *FieldList // receiver (methods); or nil (functions)
    Name *Ident // function/method name
    Type *FuncType // function signature: parameters, results, and position of "func" keyword
    Body *BlockStmt // function body; or nil for external (non-Go) function
}


我们对结构感兴趣,因此我们使用GenDecl。在这个阶段,FuncDecl可能会有用,函数的定义位于其中,您可以包装它们,但是现在我们不需要它们了。接下来,我们查看每个节点处的Specs数组,并查找是否正在使用类型定义字段(* ast.TypeSpec),这是一个结构(* ast.StructType)。确定结构后,我们检查其是否具有注释// dbe。完整的树遍历代码和使用的结构的定义如下。

遍历树并获取结构

for _, f := range node.Decls {
	genD, ok := f.(*ast.GenDecl)
	if !ok {
		fmt.Printf("SKIP %T is not *ast.GenDecl\n", f)
		continue
	}
	targetStruct := &StructInfo{}
	var thisIsStruct bool
	for _, spec := range genD.Specs {
		currType, ok := spec.(*ast.TypeSpec)
		if !ok {
			fmt.Printf("SKIP %T is not ast.TypeSpec\n", spec)
			continue
		}

		currStruct, ok := currType.Type.(*ast.StructType)
		if !ok {
			fmt.Printf("SKIP %T is not ast.StructType\n", currStruct)
			continue
		}
		targetStruct.Name = currType.Name.Name
		thisIsStruct = true
	}
	//Getting comments
	var needCodegen bool
	var dbeParams string
	if thisIsStruct {
		for _, comment := range genD.Doc.List {
			needCodegen = needCodegen || strings.HasPrefix(comment.Text, "// dbe")
			if len(comment.Text) < 7 {
				dbeParams = ""
			} else {
				dbeParams = strings.Replace(comment.Text, "// dbe:", "", 1)
			}
		}
	}
	if needCodegen {
		targetStruct.Target = genD
		genParams := &DbeParam{}
		if len(dbeParams) != 0 {
			err := json.Unmarshal([]byte(dbeParams), genParams)
			if err != nil {
				fmt.Printf("Error encoding DBE params for structure %s\n", targetStruct.Name)
				continue
			}
		} else {
			genParams.TableName = targetStruct.Name
		}

		targetStruct.GenParam = genParams
		generateMethods(targetStruct, out)
	}
}

, :

type DbeParam struct {
	TableName string `json:"table"`
}

type StructInfo struct {
	Name     string
	GenParam *DbeParam
	Target   *ast.GenDecl
}


现在,我们将准备有关结构字段的信息,以便根据收到的信息生成表创建函数(createTable)和CRUD方法。

用于从结构获取字段的代码

func generateMethods(reqStruct *StructInfo, out *os.File) {
	for _, spec := range reqStruct.Target.Specs {
		fmt.Fprintln(out, "")
		currType, ok := spec.(*ast.TypeSpec)
		if !ok {
			continue
		}
		currStruct, ok := currType.Type.(*ast.StructType)
		if !ok {
			continue
		}

		fmt.Printf("\tgenerating createTable methods for %s\n", currType.Name.Name)

		curTable := &TableInfo{
			TableName: reqStruct.GenParam.TableName,
			Columns:   make([]*ColInfo, 0, len(currStruct.Fields.List)),
		}

		for _, field := range currStruct.Fields.List {
			if len(field.Names) == 0 {
				continue
			}
			tableCol := &ColInfo{FieldName: field.Names[0].Name}
			var fieldIsPrimKey bool
			var preventThisField bool
			if field.Tag != nil {
				tag := reflect.StructTag(field.Tag.Value[1 : len(field.Tag.Value)-1])
				tagVal := tag.Get("dbe")
				fmt.Println("dbe:", tagVal)
				tagParams := strings.Split(tagVal, ",")
			PARAMSLOOP:
				for _, param := range tagParams {
					switch param {
					case "primary_key":
						if curTable.PrimaryKey == nil {
							fieldIsPrimKey = true
							tableCol.NotNull = true
						} else {
							log.Panicf("Table %s cannot have more then 1 primary key!", currType.Name.Name)
						}
					case "not_null":
						tableCol.NotNull = true
					case "-":
						preventThisField = true
						break PARAMSLOOP
					default:
						tableCol.ColName = param
					}

				}
				if preventThisField {
					continue
				}
			}
			if tableCol.ColName == "" {
				tableCol.ColName = tableCol.FieldName
			}
			if fieldIsPrimKey {
				curTable.PrimaryKey = tableCol
			}
			//Determine field type
			var fieldType string
			switch field.Type.(type) {
			case *ast.Ident:
				fieldType = field.Type.(*ast.Ident).Name
			case *ast.SelectorExpr:
				fieldType = field.Type.(*ast.SelectorExpr).Sel.Name
			}
			//fieldType := field.Type.(*ast.Ident).Name
			fmt.Printf("%s- %s\n", tableCol.FieldName, fieldType)
			//Check for integers
			if strings.Contains(fieldType, "int") {
				tableCol.ColType = "integer"
			} else {
				//Check for other types
				switch fieldType {
				case "string":
					tableCol.ColType = "text"
				case "bool":
					tableCol.ColType = "boolean"
				case "Time":
					tableCol.ColType = "TIMESTAMP"
				default:
					log.Panicf("Field type %s not supported", fieldType)
				}
			}
			tableCol.FieldType = fieldType
			curTable.Columns = append(curTable.Columns, tableCol)
			curTable.StructName = currType.Name.Name

		}
		curTable.generateCreateTable(out)

		fmt.Printf("\tgenerating CRUD methods for %s\n", currType.Name.Name)
		curTable.generateCreate(out)
		curTable.generateQuery(out)
		curTable.generateUpdate(out)
		curTable.generateDelete(out)
	}
}


我们遍历所需结构的所有字段,然后开始解析每个字段的标签。使用反射,我们得到我们感兴趣的标签(毕竟,字段上可能还有其他标签,例如json)。我们分析标签的内容,并确定该字段是否为主键(如果指定了多个主键,请对其进行诅咒并停止执行),是否要求该字段为非零值,是否需要使用该字段的数据库并定义列名称(如果在标记中被覆盖)。我们还需要根据结构字段的类型确定表列的类型。字段类型有限,我们只会为基本类型生成字段,将所有行都减少为TEXT字段类型,尽管通常可以将列类型的定义添加到标签中,以便进行更精细的配置。另一方面,没有人愿意事先在数据库中创建所需的表,或者自动更正所创建的表。

解析结构后,我们开始为表创建函数创建代码的方法以及创建Create,Query,Update,Delete函数的方法。我们为每个函数准备一个SQL表达式,并准备要运行的绑定。我没有理会错误处理,只是从数据库驱动程序中给出了错误。对于代码生成,使用文本/模板库中的模板很方便。在他们的帮助下,您可以获得更多受支持的和可预测的代码(该代码立即可见,而不会被生成器代码弄脏)。

表格创建

func (tableD *TableInfo) generateCreateTable(out *os.File) error {
	fmt.Fprint(out, "func (in *"+tableD.StructName+") createTable(db *sql.DB) (error) {\n")
	var resSQLq = fmt.Sprintf("\tsqlQ := `CREATE TABLE %s (\n", tableD.TableName)
	for _, col := range tableD.Columns {
		colSQL := col.ColName + " " + col.ColType
		if col.NotNull {
			colSQL += " NOT NULL"
		}
		if col == tableD.PrimaryKey {
			colSQL += " AUTO_INCREMENT"
		}
		colSQL += ",\n"
		resSQLq += colSQL
	}
	if tableD.PrimaryKey != nil {
		resSQLq += fmt.Sprintf("PRIMARY KEY (%s)\n", tableD.PrimaryKey.ColName)
	}
	resSQLq += ")`\n"
	fmt.Fprint(out, resSQLq)
	fmt.Fprint(out, "\t_, err := db.Exec(sqlQ)\n\t\tif err != nil {\n\t\t\treturn err\n\t\t}\n")
	fmt.Fprint(out, "\t return nil\n}\n\n")
	return nil
}


添加记录


	fmt.Fprint(out, "func (in *"+tableD.StructName+") Create(db *sql.DB) (error) {\n")
	var columns, valuePlaces, valuesListParams string
	for _, col := range tableD.Columns {
		if col == tableD.PrimaryKey {
			continue
		}
		columns += "`" + col.ColName + "`,"
		valuePlaces += "?,"
		valuesListParams += "in." + col.FieldName + ","
	}
	columns = columns[:len(columns)-1]
	valuePlaces = valuePlaces[:len(valuePlaces)-1]
	valuesListParams = valuesListParams[:len(valuesListParams)-1]

	resSQLq := fmt.Sprintf("\tsqlQ := \"INSERT INTO %s (%s) VALUES (%s);\"\n",
		tableD.TableName,
		columns,
		valuePlaces)
	fmt.Fprintln(out, resSQLq)
	fmt.Fprintf(out, "result, err := db.Exec(sqlQ, %s)\n", valuesListParams)
	fmt.Fprintln(out, `if err != nil {
		return err
	}`)
	//Setting id if we have primary key
	if tableD.PrimaryKey != nil {
		fmt.Fprintf(out, `lastId, err := result.LastInsertId()
		if err != nil {
			return nil
		}`)
		fmt.Fprintf(out, "\nin.%s = %s(lastId)\n", tableD.PrimaryKey.FieldName, tableD.PrimaryKey.FieldType)
	}
	fmt.Fprintln(out, "return nil\n}\n\n")
	//in., _ := result.LastInsertId()`)
	return nil
}


从表中检索记录

func (tableD *TableInfo) generateQuery(out *os.File) error {
	fmt.Fprint(out, "func (in *"+tableD.StructName+") Query(db *sql.DB) ([]*"+tableD.StructName+", error) {\n")

	fmt.Fprintf(out, "\tsqlQ := \"SELECT * FROM %s;\"\n", tableD.TableName)
	fmt.Fprintf(out, "rows, err := db.Query(sqlQ)\n")
	fmt.Fprintf(out, "results := make([]*%s, 0)\n", tableD.StructName)
	fmt.Fprintf(out, `for rows.Next() {`)
	fmt.Fprintf(out, "\t tempR := &%s{}\n", tableD.StructName)
	var valuesListParams string
	for _, col := range tableD.Columns {
		valuesListParams += "&tempR." + col.FieldName + ","
	}
	valuesListParams = valuesListParams[:len(valuesListParams)-1]

	fmt.Fprintf(out, "\terr = rows.Scan(%s)\n", valuesListParams)
	fmt.Fprintf(out, `if err != nil {
		return nil, err
		}`)
	fmt.Fprintf(out, "\n\tresults = append(results, tempR)")
	fmt.Fprintf(out, `}
		return results, nil
	}`)
	fmt.Fprintln(out, "")
	fmt.Fprintln(out, "")
	return nil
}


记录更新(通过主键工作)

func (tableD *TableInfo) generateUpdate(out *os.File) error {
	fmt.Fprint(out, "func (in *"+tableD.StructName+") Update(db *sql.DB) (error) {\n")
	var updVals, valuesListParams string
	for _, col := range tableD.Columns {
		if col == tableD.PrimaryKey {
			continue
		}
		updVals += "`" + col.ColName + "`=?,"
		valuesListParams += "in." + col.FieldName + ","
	}
	updVals = updVals[:len(updVals)-1]
	valuesListParams += "in." + tableD.PrimaryKey.FieldName

	resSQLq := fmt.Sprintf("\tsqlQ := \"UPDATE %s SET %s WHERE %s = ?;\"\n",
		tableD.TableName,
		updVals,
		tableD.PrimaryKey.ColName)
	fmt.Fprintln(out, resSQLq)
	fmt.Fprintf(out, "_, err := db.Exec(sqlQ, %s)\n", valuesListParams)
	fmt.Fprintln(out, `if err != nil {
		return err
	}`)

	fmt.Fprintln(out, "return nil\n}\n\n")
	//in., _ := result.LastInsertId()`)
	return nil
}


删除记录(通过主键工作)

func (tableD *TableInfo) generateDelete(out *os.File) error {
	fmt.Fprint(out, "func (in *"+tableD.StructName+") Delete(db *sql.DB) (error) {\n")
	fmt.Fprintf(out, "sqlQ := \"DELETE FROM %s WHERE id = ?\"\n", tableD.TableName)

	fmt.Fprintf(out, "_, err := db.Exec(sqlQ, in.%s)\n", tableD.PrimaryKey.FieldName)

	fmt.Fprintln(out, `if err != nil {
		return err
	}
	return nil
}`)
	fmt.Fprintln(out)
	return nil
}


生成的代码生成器的启动是通过通常的go run执行的,我们在-name标志中将路径传递到要为其生成代码的文件。结果,我们获得了后缀_dbe的文件,生成的代码位于该文件中。对于测试,请为以下结构创建方法:


// dbe:{"table": "users"}
type User struct {
	ID       int    `dbe:"id,primary_key"`
	Login    string `dbe:"login,not_null"`
	Email    string
	Level    uint8
	IsActive bool
	UError   error `dbe:"-"`
}

结果代码

package main

import "database/sql"

func (in *User) createTable(db *sql.DB) error {
	sqlQ := `CREATE TABLE users (
	id integer NOT NULL AUTO_INCREMENT,
	login text NOT NULL,
	Email text,
	Level integer,
	IsActive boolean,
	PRIMARY KEY (id)
	)`
	_, err := db.Exec(sqlQ)
	if err != nil {
		return err
	}
	return nil
}

func (in *User) Create(db *sql.DB) error {
	sqlQ := "INSERT INTO users (`login`,`Email`,`Level`,`IsActive`) VALUES (?,?,?,?);"

	result, err := db.Exec(sqlQ, in.Login, in.Email, in.Level, in.IsActive)
	if err != nil {
		return err
	}
	lastId, err := result.LastInsertId()
	if err != nil {
		return nil
	}
	in.ID = int(lastId)
	return nil
}

func (in *User) Query(db *sql.DB) ([]*User, error) {
	sqlQ := "SELECT * FROM users;"
	rows, err := db.Query(sqlQ)
	results := make([]*User, 0)
	for rows.Next() {
		tempR := &User{}
		err = rows.Scan(&tempR.ID, &tempR.Login, &tempR.Email, &tempR.Level, &tempR.IsActive)
		if err != nil {
			return nil, err
		}
		results = append(results, tempR)
	}
	return results, nil
}

func (in *User) Update(db *sql.DB) error {
	sqlQ := "UPDATE users SET `login`=?,`Email`=?,`Level`=?,`IsActive`=? WHERE id = ?;"

	_, err := db.Exec(sqlQ, in.Login, in.Email, in.Level, in.IsActive, in.ID)
	if err != nil {
		return err
	}
	return nil
}

func (in *User) Delete(db *sql.DB) error {
	sqlQ := "DELETE FROM users WHERE id = ?"
	_, err := db.Exec(sqlQ, in.ID)
	if err != nil {
		return err
	}
	return nil
}


要测试所生成代码的操作,请创建具有任意数据的对象,并为其创建一个表(如果该表存在于数据库中,则将返回错误)。将这个对象放在表中之后,从表中读取所有字段,更新级别值并删除对象。

调用结果方法

var err error
db, err := sql.Open("mysql", DSN)
if err != nil {
	fmt.Println("Unable to connect to DB", err)
	return
}
err = db.Ping()
if err != nil {
	fmt.Println("Unable to ping BD")
	return
}
newUser := &User{
	Login:    "newUser",
	Email:    "new@test.com",
	Level:    0,
	IsActive: false,
	UError:   nil,
}

err = newUser.createTable(db)
if err != nil {
	fmt.Println("Error creating table.", err)

}
err = newUser.Create(db)
if err != nil {
	fmt.Println("Error creating user.", err)
	return
}

nU := &User{}
dbUsers, err := nU.Query(db)
if err != nil {
	fmt.Println("Error selecting users.", err)
	return
}
fmt.Printf("From table users selected %d fields", len(dbUsers))
var DBUser *User
for _, user := range dbUsers {
	fmt.Println(user)
	DBUser = user
}
DBUser.Level = 2
err = DBUser.Update(db)
if err != nil {
	fmt.Println("Error updating users.", err)
	return
}
err = DBUser.Delete(db)
if err != nil {
	fmt.Println("Error deleting users.", err)
	return
}


在当前的实现中,客户端对数据库的功能非常有限:

  • 仅支持MySQL;
  • 并非所有字段类型都受支持。
  • SELECT没有过滤和限制。

但是,修复错误已经超出了解析Go源代码并基于它生成新代码的范围。

在这种情况下使用代码生成器将使您可以仅在一个地方更改应用程序中使用的字段和结构类型,无需记住为与数据库交互而对代码进行更改,您只需要每次运行代码生成器即可。可以借助反射来解决此任务,但这会影响性能。

源代码生成器以及在Github上发布的生成代码的示例

All Articles