Code generation in Go: building a database client as an example

In this article, I would like to look at code generation in Go. I have noticed that comments on Go articles often mention code generation and reflection, and both topics tend to spark heated debate. Yet there are few articles about code generation on Habr, even though it is used quite a lot in Go projects. I will try to explain what code generation is and where it applies, with code examples. I will not ignore reflection either.

When code generation is used


Habr already has good articles on this topic (here and here), so I will not repeat them.

Code generation is worth using to:

  • speed up code, in particular by replacing reflection;
  • reduce the programmer's routine work (and the errors that come with it);
  • implement wrappers according to given rules.

One example is the stringer tool (golang.org/x/tools/cmd/stringer), which generates String() methods for sets of numeric constants. With it, you can print the names of constants instead of their numeric values. Its use is described in detail in the articles mentioned above; the most interesting example there prints color names from a palette. Thanks to code generation, changing the palette does not require edits in several places.

A more practical example is the easyjson library from Mail.ru. It speeds up marshaling and unmarshaling JSON to and from structs; in the authors' benchmarks it outperformed the alternatives. To use it, run the easyjson tool: it generates code either for all structs it finds in the given file (with the -all flag) or only for those marked with the //easyjson:json comment. Take the User struct as an example:

type User struct {
	ID    int
	Login string
	Email string
	Level int
}

Run code generation for the file that contains it:

easyjson -all main.go

As a result, we get methods for User:

  • MarshalEasyJSON(w *jwriter.Writer) - writes the struct as JSON;
  • UnmarshalEasyJSON(l *jlexer.Lexer) - fills the struct from JSON bytes.

It also generates MarshalJSON() ([]byte, error) and UnmarshalJSON(data []byte) error, which make the type satisfy the json.Marshaler and json.Unmarshaler interfaces of the standard encoding/json package.

Easyjson code

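func TestEasyJSON() {
	testJSON := `{"ID":123, "Login":"TestUser", "Email":"user@gmail.com", "Level":12}`
	fmt.Println(testJSON)

	// Decode with the generated UnmarshalJSON (no runtime reflection).
	recvUser := &User{}
	if err := recvUser.UnmarshalJSON([]byte(testJSON)); err != nil {
		fmt.Println("unmarshal error:", err)
		return
	}
	fmt.Println(recvUser)

	// Change a field and encode the struct back with the generated MarshalJSON.
	recvUser.Level++
	outJSON, err := recvUser.MarshalJSON()
	if err != nil {
		fmt.Println("marshal error:", err)
		return
	}
	fmt.Println(string(outJSON))
}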


This function decodes JSON into the struct, increments Level, and prints the re-encoded JSON. Because easyjson generates the encoding code ahead of time, no runtime reflection is needed, which improves performance.

Code generation is widely used to build microservices that communicate over gRPC. Service methods are described in the protobuf format, using the Protocol Buffers interface definition language (IDL). Once the service is described, the protoc compiler generates code for the target programming language. The generated code contains the interfaces the server must implement and the methods the client uses to communicate. This is convenient: services are described once, in a single format, and code is generated for whichever language each participant in the interaction is written in.

Code generation is also useful in framework development, to produce code that the application developer should not have to write but that is required for correct operation: for example, form field validators, automatically generated middleware, or dynamically generated DBMS clients.

Go Code Generator Implementation


Let us look at how code generation in Go works in practice. First, we need the AST (Abstract Syntax Tree); see Wikipedia for details. For our purposes, it is enough to understand that the whole program is represented as a tree whose inner nodes are labeled with the language's operators and whose leaves are labeled with the corresponding operands.

So, to start with, we need these packages:

  • go/ast
  • go/parser
  • go/token

Parsing the source file and building the tree takes the following calls:


fset := token.NewFileSet()
node, err := parser.ParseFile(fset, os.Args[1], nil, parser.ParseComments)
if err != nil {
	log.Fatal(err) // the file is missing or is not valid Go
}

The file name is taken from the first command-line argument, and the parser.ParseComments mode tells the parser to keep comments in the tree.
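To experiment without a file on disk, ParseFile also accepts the source itself as its third argument (a string, []byte or io.Reader). The sketch below parses an in-memory snippet and shows that, thanks to parser.ParseComments, the comment above a type ends up in its GenDecl's Doc field. The src constant and the docComments helper are illustrative names, not part of the author's code.

```go
package main

import (
	"fmt"
	"go/ast"
	"go/parser"
	"go/token"
)

const src = `package demo

// dbe:{"table": "users"}
type User struct {
	ID int
}
`

// docComments maps each declared type name to the text of its doc comment.
func docComments(code string) (map[string]string, error) {
	fset := token.NewFileSet()
	// With a non-nil src, the file name is only used for positions.
	node, err := parser.ParseFile(fset, "demo.go", code, parser.ParseComments)
	if err != nil {
		return nil, err
	}
	docs := make(map[string]string)
	for _, decl := range node.Decls {
		genD, ok := decl.(*ast.GenDecl)
		if !ok || genD.Tok != token.TYPE || genD.Doc == nil {
			continue
		}
		for _, spec := range genD.Specs {
			// Text() strips the comment markers ("// ").
			docs[spec.(*ast.TypeSpec).Name.Name] = genD.Doc.Text()
		}
	}
	return docs, nil
}

func main() {
	docs, err := docComments(src)
	if err != nil {
		fmt.Println(err)
		return
	}
	fmt.Printf("%q\n", docs["User"])
}
```

Without parser.ParseComments the Doc field would be nil, and a generator relying on marker comments would see nothing.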

In general, the user of a generator (the developer whose code serves as the basis for generating other code) can control generation through comments or struct tags, like the `json:"..."` tags we write next to struct fields.

As an example, we will write a code generator for working with a database. The generator will scan the file it is given, look for structs marked with a special comment, and create a wrapper for each of them (CRUD methods) for interacting with the database. We will use the following parameters:

  • a dbe comment on the struct, such as // dbe:{"table": "users"}, which sets the table that stores the struct's records;
  • a dbe tag on struct fields, which sets the name of the column that stores the field's value and the database attributes primary_key and not_null (used when creating the table). Writing "-" instead of a name excludes the field, so no column is created for it.

A disclaimer up front: the project is not production-ready and lacks some necessary checks and safeguards. If there is interest, I will keep developing it.

With the task and the generation parameters defined, we can start writing code.

A link to the full code is at the end of the article.

We walk the resulting tree and examine each top-level element. The go/ast package defines three declaration node types for this: BadDecl, GenDecl and FuncDecl.

Declaration node types in go/ast
// A BadDecl node is a placeholder for declarations containing
// syntax errors for which no correct declaration nodes can be
// created.
//
BadDecl struct {
    From, To token.Pos // position range of bad declaration
}
// A GenDecl node (generic declaration node) represents an import,
// constant, type or variable declaration. A valid Lparen position
// (Lparen.IsValid()) indicates a parenthesized declaration.
//
// Relationship between Tok value and Specs element type:
//
// token.IMPORT *ImportSpec
// token.CONST *ValueSpec
// token.TYPE *TypeSpec
// token.VAR *ValueSpec
//
GenDecl struct {
    Doc *CommentGroup // associated documentation; or nil
    TokPos token.Pos // position of Tok
    Tok token.Token // IMPORT, CONST, TYPE, VAR
    Lparen token.Pos // position of '(', if any
    Specs []Spec
    Rparen token.Pos // position of ')', if any
}
// A FuncDecl node represents a function declaration.
FuncDecl struct {
    Doc *CommentGroup // associated documentation; or nil
    Recv *FieldList // receiver (methods); or nil (functions)
    Name *Ident // function/method name
    Type *FuncType // function signature: parameters, results, and position of "func" keyword
    Body *BlockStmt // function body; or nil for external (non-Go) function
}


We are interested in structs, so we use GenDecl. FuncDecl, which holds function definitions, could be useful if we wanted to wrap functions, but we do not need it now. Next, for each node we go through the Specs slice and check that the element is a type definition (*ast.TypeSpec) and that the type is a struct (*ast.StructType). Once we know we have a struct, we check that it has a // dbe comment. The full tree traversal code and the logic that selects which structs to process are below.

Tree traversal and getting structures

for _, f := range node.Decls {
	genD, ok := f.(*ast.GenDecl)
	if !ok {
		fmt.Printf("SKIP %T is not *ast.GenDecl\n", f)
		continue
	}
	targetStruct := &StructInfo{}
	var thisIsStruct bool
	for _, spec := range genD.Specs {
		currType, ok := spec.(*ast.TypeSpec)
		if !ok {
			fmt.Printf("SKIP %T is not ast.TypeSpec\n", spec)
			continue
		}

		if _, ok := currType.Type.(*ast.StructType); !ok {
			// Print the actual type expression, not a nil *ast.StructType.
			fmt.Printf("SKIP %T is not *ast.StructType\n", currType.Type)
			continue
		}
		targetStruct.Name = currType.Name.Name
		thisIsStruct = true
	}
	// Getting comments: the "// dbe" marker and optional JSON parameters
	var needCodegen bool
	var dbeParams string
	// Doc is nil when the declaration has no comment at all.
	if thisIsStruct && genD.Doc != nil {
		for _, comment := range genD.Doc.List {
			if !strings.HasPrefix(comment.Text, "// dbe") {
				continue
			}
			needCodegen = true
			if strings.HasPrefix(comment.Text, "// dbe:") {
				dbeParams = strings.TrimSpace(strings.TrimPrefix(comment.Text, "// dbe:"))
			}
		}
	}
	if needCodegen {
		targetStruct.Target = genD
		genParams := &DbeParam{}
		if len(dbeParams) != 0 {
			if err := json.Unmarshal([]byte(dbeParams), genParams); err != nil {
				fmt.Printf("Error decoding DBE params for structure %s: %v\n", targetStruct.Name, err)
				continue
			}
		}
		// Fall back to the struct name when no table name is given.
		if genParams.TableName == "" {
			genParams.TableName = targetStruct.Name
		}

		targetStruct.GenParam = genParams
		generateMethods(targetStruct, out)
	}
}

The helper types used above:

type DbeParam struct {
	TableName string `json:"table"`
}

type StructInfo struct {
	Name     string
	GenParam *DbeParam
	Target   *ast.GenDecl
}
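The generator code that follows also relies on two types, TableInfo and ColInfo, whose definitions are not shown in the text. The sketch below reconstructs them from the fields that code accesses; treat the exact layout as an assumption, since the author's versions may differ.

```go
package main

import "fmt"

// ColInfo describes one table column derived from a struct field.
// Reconstructed from usage; not the author's exact definition.
type ColInfo struct {
	FieldName string // Go field name, e.g. "Login"
	FieldType string // Go type name, e.g. "string" or "Time"
	ColName   string // column name from the dbe tag, or the field name
	ColType   string // SQL column type, e.g. "text"
	NotNull   bool   // set by not_null and by primary_key
}

// TableInfo collects what is needed to generate createTable and CRUD methods.
type TableInfo struct {
	TableName  string
	StructName string
	Columns    []*ColInfo
	PrimaryKey *ColInfo // points into Columns; nil if there is no primary key
}

func main() {
	id := &ColInfo{FieldName: "ID", FieldType: "int", ColName: "id", ColType: "integer", NotNull: true}
	t := &TableInfo{TableName: "users", StructName: "User", Columns: []*ColInfo{id}, PrimaryKey: id}
	// The generator compares pointers (col == tableD.PrimaryKey), so the
	// primary key must be the same *ColInfo that is stored in Columns.
	fmt.Println(t.Columns[0] == t.PrimaryKey)
}
```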


Next, we collect information about the struct's fields; from it we will generate the table creation function (createTable) and the CRUD methods.

Code for getting fields from a structure

func generateMethods(reqStruct *StructInfo, out *os.File) {
	for _, spec := range reqStruct.Target.Specs {
		fmt.Fprintln(out, "")
		currType, ok := spec.(*ast.TypeSpec)
		if !ok {
			continue
		}
		currStruct, ok := currType.Type.(*ast.StructType)
		if !ok {
			continue
		}

		fmt.Printf("\tgenerating createTable methods for %s\n", currType.Name.Name)

		curTable := &TableInfo{
			TableName: reqStruct.GenParam.TableName,
			Columns:   make([]*ColInfo, 0, len(currStruct.Fields.List)),
		}

		for _, field := range currStruct.Fields.List {
			if len(field.Names) == 0 {
				continue
			}
			tableCol := &ColInfo{FieldName: field.Names[0].Name}
			var fieldIsPrimKey bool
			var preventThisField bool
			if field.Tag != nil {
				tag := reflect.StructTag(field.Tag.Value[1 : len(field.Tag.Value)-1])
				tagVal := tag.Get("dbe")
				fmt.Println("dbe:", tagVal)
				tagParams := strings.Split(tagVal, ",")
			PARAMSLOOP:
				for _, param := range tagParams {
					switch param {
					case "primary_key":
						if curTable.PrimaryKey == nil {
							fieldIsPrimKey = true
							tableCol.NotNull = true
						} else {
							log.Panicf("Table %s cannot have more than one primary key!", currType.Name.Name)
						}
					case "not_null":
						tableCol.NotNull = true
					case "-":
						preventThisField = true
						break PARAMSLOOP
					default:
						tableCol.ColName = param
					}

				}
				if preventThisField {
					continue
				}
			}
			if tableCol.ColName == "" {
				tableCol.ColName = tableCol.FieldName
			}
			if fieldIsPrimKey {
				curTable.PrimaryKey = tableCol
			}
			// Determine the field type: a plain identifier (int, string, ...)
			// or a qualified name such as time.Time.
			var fieldType string
			switch t := field.Type.(type) {
			case *ast.Ident:
				fieldType = t.Name
			case *ast.SelectorExpr:
				fieldType = t.Sel.Name
			}
			fmt.Printf("%s- %s\n", tableCol.FieldName, fieldType)
			//Check for integers
			if strings.Contains(fieldType, "int") {
				tableCol.ColType = "integer"
			} else {
				//Check for other types
				switch fieldType {
				case "string":
					tableCol.ColType = "text"
				case "bool":
					tableCol.ColType = "boolean"
				case "Time":
					tableCol.ColType = "TIMESTAMP"
				default:
					log.Panicf("Field type %s not supported", fieldType)
				}
			}
			tableCol.FieldType = fieldType
			curTable.Columns = append(curTable.Columns, tableCol)
			curTable.StructName = currType.Name.Name

		}
		curTable.generateCreateTable(out)

		fmt.Printf("\tgenerating CRUD methods for %s\n", currType.Name.Name)
		curTable.generateCreate(out)
		curTable.generateQuery(out)
		curTable.generateUpdate(out)
		curTable.generateDelete(out)
	}
}


We go through all the fields of the struct and parse each field's tags. reflect.StructTag lets us extract the tag we are interested in (a field may carry other tags too, for example for json); note that this only parses the tag string and involves no runtime reflection on values. We analyze the tag contents and determine whether the field is the primary key (if more than one primary key is specified, we report it and stop), whether the column must be NOT NULL, whether the field should be stored in the database at all, and whether the tag overrides the column name. We also need to derive the column type from the field type. The set of field types is finite, so we generate code only for basic types and map all strings to TEXT; in general, you could add a column type to the tag for finer control. Besides, nothing prevents you from creating the table in the database in advance or adjusting the one created automatically.

After parsing the struct, we call the functions that generate the table creation function and the Create, Query, Update and Delete methods. For each method we prepare an SQL statement and the arguments bound to it. I did not bother with error handling: the generated methods simply return the database driver's error. For code generation, templates from the text/template package are convenient: they give much more maintainable and predictable code, because the shape of the generated code is visible at once instead of being scattered across the generator code.

Table creation

func (tableD *TableInfo) generateCreateTable(out *os.File) error {
	fmt.Fprintf(out, "func (in *%s) createTable(db *sql.DB) error {\n", tableD.StructName)
	// Collect column definitions first; joining them avoids a trailing
	// comma when the table has no primary key.
	defs := make([]string, 0, len(tableD.Columns)+1)
	for _, col := range tableD.Columns {
		colSQL := col.ColName + " " + col.ColType
		if col.NotNull {
			colSQL += " NOT NULL"
		}
		if col == tableD.PrimaryKey {
			colSQL += " AUTO_INCREMENT" // MySQL syntax
		}
		defs = append(defs, colSQL)
	}
	if tableD.PrimaryKey != nil {
		defs = append(defs, fmt.Sprintf("PRIMARY KEY (%s)", tableD.PrimaryKey.ColName))
	}
	fmt.Fprintf(out, "\tsqlQ := `CREATE TABLE %s (\n\t%s\n\t)`\n",
		tableD.TableName, strings.Join(defs, ",\n\t"))
	fmt.Fprint(out, "\t_, err := db.Exec(sqlQ)\n\tif err != nil {\n\t\treturn err\n\t}\n")
	fmt.Fprint(out, "\treturn nil\n}\n\n")
	return nil
}


Add Record

func (tableD *TableInfo) generateCreate(out *os.File) error {
	fmt.Fprintf(out, "func (in *%s) Create(db *sql.DB) error {\n", tableD.StructName)
	// The primary key is skipped: the database assigns it (AUTO_INCREMENT).
	var columns, valuePlaces, valuesListParams []string
	for _, col := range tableD.Columns {
		if col == tableD.PrimaryKey {
			continue
		}
		columns = append(columns, "`"+col.ColName+"`")
		valuePlaces = append(valuePlaces, "?")
		valuesListParams = append(valuesListParams, "in."+col.FieldName)
	}

	fmt.Fprintf(out, "\tsqlQ := \"INSERT INTO %s (%s) VALUES (%s);\"\n",
		tableD.TableName,
		strings.Join(columns, ","),
		strings.Join(valuePlaces, ","))
	// The result is only needed to read back the generated id.
	resultVar := "_"
	if tableD.PrimaryKey != nil {
		resultVar = "result"
	}
	fmt.Fprintf(out, "\t%s, err := db.Exec(sqlQ, %s)\n", resultVar, strings.Join(valuesListParams, ", "))
	fmt.Fprint(out, "\tif err != nil {\n\t\treturn err\n\t}\n")
	// Copy the generated id back into the struct if there is a primary key.
	if tableD.PrimaryKey != nil {
		fmt.Fprint(out, "\tlastID, err := result.LastInsertId()\n\tif err != nil {\n\t\treturn err\n\t}\n")
		fmt.Fprintf(out, "\tin.%s = %s(lastID)\n", tableD.PrimaryKey.FieldName, tableD.PrimaryKey.FieldType)
	}
	fmt.Fprint(out, "\treturn nil\n}\n\n")
	return nil
}


Retrieving records from a table

func (tableD *TableInfo) generateQuery(out *os.File) error {
	fmt.Fprintf(out, "func (in *%s) Query(db *sql.DB) ([]*%s, error) {\n", tableD.StructName, tableD.StructName)
	// SELECT * relies on the column order matching the struct field order,
	// which holds for tables created by the generated createTable.
	fmt.Fprintf(out, "\tsqlQ := \"SELECT * FROM %s;\"\n", tableD.TableName)
	fmt.Fprint(out, "\trows, err := db.Query(sqlQ)\n\tif err != nil {\n\t\treturn nil, err\n\t}\n\tdefer rows.Close()\n")
	fmt.Fprintf(out, "\tresults := make([]*%s, 0)\n", tableD.StructName)
	fmt.Fprintf(out, "\tfor rows.Next() {\n\t\ttempR := &%s{}\n", tableD.StructName)
	scanArgs := make([]string, 0, len(tableD.Columns))
	for _, col := range tableD.Columns {
		scanArgs = append(scanArgs, "&tempR."+col.FieldName)
	}
	fmt.Fprintf(out, "\t\tif err := rows.Scan(%s); err != nil {\n\t\t\treturn nil, err\n\t\t}\n",
		strings.Join(scanArgs, ", "))
	fmt.Fprint(out, "\t\tresults = append(results, tempR)\n\t}\n")
	// rows.Err reports an error that ended the iteration early.
	fmt.Fprint(out, "\treturn results, rows.Err()\n}\n\n")
	return nil
}


Record update (works by primary key)

func (tableD *TableInfo) generateUpdate(out *os.File) error {
	// Updates are keyed by the primary key; without one there is nothing to generate.
	if tableD.PrimaryKey == nil {
		return fmt.Errorf("table %s has no primary key, Update not generated", tableD.TableName)
	}
	fmt.Fprintf(out, "func (in *%s) Update(db *sql.DB) error {\n", tableD.StructName)
	var updVals, valuesListParams []string
	for _, col := range tableD.Columns {
		if col == tableD.PrimaryKey {
			continue
		}
		updVals = append(updVals, "`"+col.ColName+"`=?")
		valuesListParams = append(valuesListParams, "in."+col.FieldName)
	}
	// The primary key binds to the WHERE placeholder, so it goes last.
	valuesListParams = append(valuesListParams, "in."+tableD.PrimaryKey.FieldName)

	fmt.Fprintf(out, "\tsqlQ := \"UPDATE %s SET %s WHERE %s = ?;\"\n",
		tableD.TableName,
		strings.Join(updVals, ","),
		tableD.PrimaryKey.ColName)
	fmt.Fprintf(out, "\t_, err := db.Exec(sqlQ, %s)\n", strings.Join(valuesListParams, ", "))
	fmt.Fprint(out, "\tif err != nil {\n\t\treturn err\n\t}\n\treturn nil\n}\n\n")
	return nil
}


Delete a record (works by primary key)

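func (tableD *TableInfo) generateDelete(out *os.File) error {
	// Deletes are keyed by the primary key as well.
	if tableD.PrimaryKey == nil {
		return fmt.Errorf("table %s has no primary key, Delete not generated", tableD.TableName)
	}
	fmt.Fprintf(out, "func (in *%s) Delete(db *sql.DB) error {\n", tableD.StructName)
	// Use the real primary key column instead of assuming it is called "id".
	fmt.Fprintf(out, "\tsqlQ := \"DELETE FROM %s WHERE %s = ?\"\n", tableD.TableName, tableD.PrimaryKey.ColName)
	fmt.Fprintf(out, "\t_, err := db.Exec(sqlQ, in.%s)\n", tableD.PrimaryKey.FieldName)
	fmt.Fprint(out, "\tif err != nil {\n\t\treturn err\n\t}\n\treturn nil\n}\n\n")
	return nil
}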


The generator is run with an ordinary go run, passing the path of the file to generate code for in the -name flag. The result is a file with the _dbe suffix that contains the generated code. For testing, we generate methods for the following struct:


// dbe:{"table": "users"}
type User struct {
	ID       int    `dbe:"id,primary_key"`
	Login    string `dbe:"login,not_null"`
	Email    string
	Level    uint8
	IsActive bool
	UError   error `dbe:"-"`
}

The resulting code

package main

import "database/sql"

func (in *User) createTable(db *sql.DB) error {
	sqlQ := `CREATE TABLE users (
	id integer NOT NULL AUTO_INCREMENT,
	login text NOT NULL,
	Email text,
	Level integer,
	IsActive boolean,
	PRIMARY KEY (id)
	)`
	_, err := db.Exec(sqlQ)
	if err != nil {
		return err
	}
	return nil
}

func (in *User) Create(db *sql.DB) error {
	sqlQ := "INSERT INTO users (`login`,`Email`,`Level`,`IsActive`) VALUES (?,?,?,?);"

	result, err := db.Exec(sqlQ, in.Login, in.Email, in.Level, in.IsActive)
	if err != nil {
		return err
	}
	lastId, err := result.LastInsertId()
	if err != nil {
		return err
	}
	in.ID = int(lastId)
	return nil
}

func (in *User) Query(db *sql.DB) ([]*User, error) {
	sqlQ := "SELECT * FROM users;"
	rows, err := db.Query(sqlQ)
	if err != nil {
		return nil, err
	}
	defer rows.Close()
	results := make([]*User, 0)
	for rows.Next() {
		tempR := &User{}
		if err := rows.Scan(&tempR.ID, &tempR.Login, &tempR.Email, &tempR.Level, &tempR.IsActive); err != nil {
			return nil, err
		}
		results = append(results, tempR)
	}
	return results, rows.Err()
}

func (in *User) Update(db *sql.DB) error {
	sqlQ := "UPDATE users SET `login`=?,`Email`=?,`Level`=?,`IsActive`=? WHERE id = ?;"

	_, err := db.Exec(sqlQ, in.Login, in.Email, in.Level, in.IsActive, in.ID)
	if err != nil {
		return err
	}
	return nil
}

func (in *User) Delete(db *sql.DB) error {
	sqlQ := "DELETE FROM users WHERE id = ?"
	_, err := db.Exec(sqlQ, in.ID)
	if err != nil {
		return err
	}
	return nil
}


To test the generated code, we create an object with arbitrary data and create a table for it (if the table already exists in the database, an error is returned). Then we insert the object into the table, read all rows from the table, update the Level value, and delete the object.

Call the resulting methods

package main

import (
	"database/sql"
	"fmt"

	// Assumed MySQL driver; it registers itself under the name "mysql".
	_ "github.com/go-sql-driver/mysql"
)

// DSN is a placeholder connection string; replace it with your own.
const DSN = "user:password@tcp(127.0.0.1:3306)/test"

func main() {
	db, err := sql.Open("mysql", DSN)
	if err != nil {
		fmt.Println("Unable to connect to DB", err)
		return
	}
	defer db.Close()
	if err = db.Ping(); err != nil {
		fmt.Println("Unable to ping DB", err)
		return
	}
	newUser := &User{
		Login:    "newUser",
		Email:    "new@test.com",
		Level:    0,
		IsActive: false,
		UError:   nil,
	}

	// Not fatal: the table may already exist from a previous run.
	if err = newUser.createTable(db); err != nil {
		fmt.Println("Error creating table.", err)
	}
	if err = newUser.Create(db); err != nil {
		fmt.Println("Error creating user.", err)
		return
	}

	nU := &User{}
	dbUsers, err := nU.Query(db)
	if err != nil {
		fmt.Println("Error selecting users.", err)
		return
	}
	fmt.Printf("Selected %d rows from table users\n", len(dbUsers))
	var DBUser *User
	for _, user := range dbUsers {
		fmt.Println(user)
		DBUser = user
	}
	if DBUser == nil {
		fmt.Println("No users found.")
		return
	}
	DBUser.Level = 2
	if err = DBUser.Update(db); err != nil {
		fmt.Println("Error updating user.", err)
		return
	}
	if err = DBUser.Delete(db); err != nil {
		fmt.Println("Error deleting user.", err)
		return
	}
}


In the current implementation, the database client is quite limited:

  • only MySQL is supported;
  • not all field types are supported;
  • SELECT has no filtering or limits.

However, addressing these limitations goes beyond the subject of this article: parsing Go source code and generating new code from it.

With a code generator, the fields and types of the structs used in the application change in one place only: there is no need to remember to update the database code, you just rerun the generator. The same task could be solved with reflection, but at a cost in run-time performance.

The source code of the generator and an example of the generated code are available on GitHub.
