This commit is contained in:
bloeys
2022-10-27 05:36:02 +04:00
parent 69888d4d90
commit d389e76b8d
2 changed files with 138 additions and 39 deletions

View File

@ -5,10 +5,13 @@ import (
"go/ast"
"go/format"
"go/token"
"io"
"os"
"strings"
"golang.org/x/tools/go/ast/astutil"
"golang.org/x/tools/go/packages"
"golang.org/x/tools/imports"
)
func printDebugInfo() {
@ -32,7 +35,12 @@ func main() {
panic(err)
}
// Parse package
genCogoFuncs(cwd)
genHasGenChecksOnOriginalFuncs(cwd)
}
func genCogoFuncs(cwd string) {
pkgs, err := packages.Load(&packages.Config{
Dir: cwd,
Mode: packages.NeedName | packages.NeedTypes | packages.NeedTypesInfo | packages.NeedSyntax,
@ -42,22 +50,20 @@ func main() {
panic(err)
}
if len(pkgs) != 1 {
panic(fmt.Sprintf("expected to find one package but found %d", len(pkgs)))
}
// if len(pkgs) != 1 {
// panic(fmt.Sprintf("expected to find one package but found %d", len(pkgs)))
// }
processPkg(pkgs[0])
}
for _, pkg := range pkgs {
func processPkg(pkg *packages.Package) {
p := processor{
p := &processor{
fset: pkg.Fset,
funcDeclsToWrite: []*ast.FuncDecl{},
}
for i, synFile := range pkg.Syntax {
pkg.Syntax[i] = astutil.Apply(synFile, p.processDeclNode, nil).(*ast.File)
pkg.Syntax[i] = astutil.Apply(synFile, p.genCogoFuncsNodeProcessor, nil).(*ast.File)
if len(p.funcDeclsToWrite) > 0 {
@ -75,15 +81,75 @@ func processPkg(pkg *packages.Package) {
})
}
// imports.Process()
err := format.Node(os.Stdout, pkg.Fset, root)
origFName := pkg.Fset.File(synFile.Pos()).Name()
newFName := strings.Split(origFName, ".")[0] + ".cogo.go"
writeAst(newFName, pkg.Fset, root)
p.funcDeclsToWrite = p.funcDeclsToWrite[:0]
}
}
}
}
func writeAst(fName string, fset *token.FileSet, node any) {
f, err := os.Create(fName)
if err != nil {
panic("Failed to create file to write new AST. Err: " + err.Error())
}
defer f.Close()
err = format.Node(f, fset, node)
if err != nil {
panic(err.Error())
}
b, err := imports.Process(fName, nil, nil)
if err != nil {
panic("Failed to process imports on file " + fName + ". Err: " + err.Error())
}
f.Seek(0, io.SeekStart)
_, err = f.Write(b)
if err != nil {
panic(err.Error())
}
}
func genHasGenChecksOnOriginalFuncs(cwd string) {
pkgs, err := packages.Load(&packages.Config{
Dir: cwd,
Mode: packages.NeedName | packages.NeedTypes | packages.NeedTypesInfo | packages.NeedSyntax,
Tests: false,
})
if err != nil {
panic(err)
}
// if len(pkgs) != 1 {
// panic(fmt.Sprintf("expected to find one package but found %d", len(pkgs)))
// }
for _, pkg := range pkgs {
p := &processor{
fset: pkg.Fset,
funcDeclsToWrite: []*ast.FuncDecl{},
}
for i, synFile := range pkg.Syntax {
pkg.Syntax[i] = astutil.Apply(synFile, p.genHasGenChecksOnOriginalFuncsNodeProcessor, nil).(*ast.File)
if len(p.funcDeclsToWrite) > 0 {
origFName := pkg.Fset.File(synFile.Pos()).Name()
writeAst(origFName, pkg.Fset, pkg.Syntax[i])
p.funcDeclsToWrite = p.funcDeclsToWrite[:0]
}
}
}
}
@ -92,7 +158,7 @@ type processor struct {
funcDeclsToWrite []*ast.FuncDecl
}
func (p *processor) processDeclNode(c *astutil.Cursor) bool {
func (p *processor) genCogoFuncsNodeProcessor(c *astutil.Cursor) bool {
n := c.Node()
if n == nil {
@ -226,6 +292,45 @@ func (p *processor) processDeclNode(c *astutil.Cursor) bool {
return true
}
func (p *processor) genHasGenChecksOnOriginalFuncsNodeProcessor(c *astutil.Cursor) bool {
n := c.Node()
if n == nil {
return false
}
funcDecl, ok := n.(*ast.FuncDecl)
if !ok || funcDecl.Body == nil {
return true
}
coroutineParamName := getCoroutineParamNameFromFuncDecl(funcDecl)
if coroutineParamName == "" || !funcDeclCallsCoroutineBegin(funcDecl, coroutineParamName) {
return false
}
for i, stmt := range funcDecl.Body.List {
// If one already exists update it and return
ifStmt, ok := stmt.(*ast.IfStmt)
if ok && ifStmtIsHasGen(ifStmt) {
funcDecl.Body.List[i] = createHasGenIfStmt(funcDecl, coroutineParamName)
p.funcDeclsToWrite = append(p.funcDeclsToWrite, funcDecl)
return true
}
}
// If the check doesn't exist add it to the beginning of the function
origList := funcDecl.Body.List
funcDecl.Body.List = make([]ast.Stmt, 0, len(origList)+1)
funcDecl.Body.List = append(funcDecl.Body.List, createHasGenIfStmt(funcDecl, coroutineParamName))
funcDecl.Body.List = append(funcDecl.Body.List, origList...)
p.funcDeclsToWrite = append(p.funcDeclsToWrite, funcDecl)
return true
}
func createHasGenIfStmt(funcDecl *ast.FuncDecl, coroutineParamName string) *ast.IfStmt {
return &ast.IfStmt{
Cond: createStmtFromSelFuncCall("cogo", "HasGen").(*ast.ExprStmt).X,

View File

@ -9,14 +9,8 @@ import (
"github.com/bloeys/cogo/cogo"
)
func test_cogo(c *cogo.Coroutine[int, int]) (out int) { return 0 }
func test(c *cogo.Coroutine[int, int]) (out int) {
if cogo.HasGen() {
return test_cogo(c)
}
c.Begin()
println("Tick 1")