diff --git a/inliner/main.go b/inliner/main.go index c8cac04..12931ae 100755 --- a/inliner/main.go +++ b/inliner/main.go @@ -5,10 +5,13 @@ import ( "go/ast" "go/format" "go/token" + "io" "os" + "strings" "golang.org/x/tools/go/ast/astutil" "golang.org/x/tools/go/packages" + "golang.org/x/tools/imports" ) func printDebugInfo() { @@ -32,7 +35,12 @@ func main() { panic(err) } - // Parse package + genCogoFuncs(cwd) + genHasGenChecksOnOriginalFuncs(cwd) +} + +func genCogoFuncs(cwd string) { + pkgs, err := packages.Load(&packages.Config{ Dir: cwd, Mode: packages.NeedName | packages.NeedTypes | packages.NeedTypesInfo | packages.NeedSyntax, @@ -42,46 +50,104 @@ func main() { panic(err) } - if len(pkgs) != 1 { - panic(fmt.Sprintf("expected to find one package but found %d", len(pkgs))) - } + // if len(pkgs) != 1 { + // panic(fmt.Sprintf("expected to find one package but found %d", len(pkgs))) + // } - processPkg(pkgs[0]) + for _, pkg := range pkgs { + + p := &processor{ + fset: pkg.Fset, + funcDeclsToWrite: []*ast.FuncDecl{}, + } + + for i, synFile := range pkg.Syntax { + + pkg.Syntax[i] = astutil.Apply(synFile, p.genCogoFuncsNodeProcessor, nil).(*ast.File) + + if len(p.funcDeclsToWrite) > 0 { + + root := &ast.File{ + Name: synFile.Name, + Imports: synFile.Imports, + Decls: []ast.Decl{}, + } + + for _, v := range p.funcDeclsToWrite { + root.Decls = append(root.Decls, &ast.FuncDecl{ + Name: ast.NewIdent(v.Name.Name + "_cogo"), + Type: v.Type, + Body: v.Body, + }) + } + + origFName := pkg.Fset.File(synFile.Pos()).Name() + newFName := strings.Split(origFName, ".")[0] + ".cogo.go" + writeAst(newFName, pkg.Fset, root) + + p.funcDeclsToWrite = p.funcDeclsToWrite[:0] + } + } + } } -func processPkg(pkg *packages.Package) { +func writeAst(fName string, fset *token.FileSet, node any) { - p := processor{ - fset: pkg.Fset, - funcDeclsToWrite: []*ast.FuncDecl{}, + f, err := os.Create(fName) + if err != nil { + panic("Failed to create file to write new AST. Err: " + err.Error()) } - for i, synFile := range pkg.Syntax { + defer f.Close() - pkg.Syntax[i] = astutil.Apply(synFile, p.processDeclNode, nil).(*ast.File) + err = format.Node(f, fset, node) + if err != nil { + panic(err.Error()) + } - if len(p.funcDeclsToWrite) > 0 { + b, err := imports.Process(fName, nil, nil) + if err != nil { + panic("Failed to process imports on file " + fName + ". Err: " + err.Error()) + } - root := &ast.File{ - Name: synFile.Name, - Imports: synFile.Imports, - Decls: []ast.Decl{}, + f.Seek(0, io.SeekStart) + _, err = f.Write(b) + if err != nil { + panic(err.Error()) + } +} + +func genHasGenChecksOnOriginalFuncs(cwd string) { + + pkgs, err := packages.Load(&packages.Config{ + Dir: cwd, + Mode: packages.NeedName | packages.NeedTypes | packages.NeedTypesInfo | packages.NeedSyntax, + Tests: false, + }) + if err != nil { + panic(err) + } + + // if len(pkgs) != 1 { + // panic(fmt.Sprintf("expected to find one package but found %d", len(pkgs))) + // } + + for _, pkg := range pkgs { + + p := &processor{ + fset: pkg.Fset, + funcDeclsToWrite: []*ast.FuncDecl{}, + } + + for i, synFile := range pkg.Syntax { + + pkg.Syntax[i] = astutil.Apply(synFile, p.genHasGenChecksOnOriginalFuncsNodeProcessor, nil).(*ast.File) + + if len(p.funcDeclsToWrite) > 0 { + origFName := pkg.Fset.File(synFile.Pos()).Name() + writeAst(origFName, pkg.Fset, pkg.Syntax[i]) + + p.funcDeclsToWrite = p.funcDeclsToWrite[:0] } - - for _, v := range p.funcDeclsToWrite { - root.Decls = append(root.Decls, &ast.FuncDecl{ - Name: ast.NewIdent(v.Name.Name + "_cogo"), - Type: v.Type, - Body: v.Body, - }) - } - - // imports.Process() - err := format.Node(os.Stdout, pkg.Fset, root) - if err != nil { - panic(err.Error()) - } - - p.funcDeclsToWrite = p.funcDeclsToWrite[:0] } } @@ -92,7 +158,7 @@ type processor struct { funcDeclsToWrite []*ast.FuncDecl } -func (p *processor) processDeclNode(c *astutil.Cursor) bool { +func (p *processor) genCogoFuncsNodeProcessor(c *astutil.Cursor) bool { n := c.Node() if n == nil { @@ -226,6 +292,45 @@ func (p *processor) processDeclNode(c *astutil.Cursor) bool { return true } +func (p *processor) genHasGenChecksOnOriginalFuncsNodeProcessor(c *astutil.Cursor) bool { + + n := c.Node() + if n == nil { + return false + } + + funcDecl, ok := n.(*ast.FuncDecl) + if !ok || funcDecl.Body == nil { + return true + } + + coroutineParamName := getCoroutineParamNameFromFuncDecl(funcDecl) + if coroutineParamName == "" || !funcDeclCallsCoroutineBegin(funcDecl, coroutineParamName) { + return false + } + + for i, stmt := range funcDecl.Body.List { + + // If one already exists update it and return + ifStmt, ok := stmt.(*ast.IfStmt) + if ok && ifStmtIsHasGen(ifStmt) { + funcDecl.Body.List[i] = createHasGenIfStmt(funcDecl, coroutineParamName) + p.funcDeclsToWrite = append(p.funcDeclsToWrite, funcDecl) + return true + } + } + + // If the check doesn't exist add it to the beginning of the function + origList := funcDecl.Body.List + funcDecl.Body.List = make([]ast.Stmt, 0, len(origList)+1) + funcDecl.Body.List = append(funcDecl.Body.List, createHasGenIfStmt(funcDecl, coroutineParamName)) + funcDecl.Body.List = append(funcDecl.Body.List, origList...) + + p.funcDeclsToWrite = append(p.funcDeclsToWrite, funcDecl) + + return true +} + func createHasGenIfStmt(funcDecl *ast.FuncDecl, coroutineParamName string) *ast.IfStmt { return &ast.IfStmt{ Cond: createStmtFromSelFuncCall("cogo", "HasGen").(*ast.ExprStmt).X, diff --git a/main.go b/main.go index ba881c6..4851e62 100755 --- a/main.go +++ b/main.go @@ -9,14 +9,8 @@ import ( "github.com/bloeys/cogo/cogo" ) -func test_cogo(c *cogo.Coroutine[int, int]) (out int) { return 0 } - func test(c *cogo.Coroutine[int, int]) (out int) { - if cogo.HasGen() { - return test_cogo(c) - } - c.Begin() println("Tick 1")