mirror of
https://github.com/bloeys/cogo.git
synced 2025-12-29 08:58:19 +00:00
Work
This commit is contained in:
171
inliner/main.go
171
inliner/main.go
@ -5,10 +5,13 @@ import (
|
|||||||
"go/ast"
|
"go/ast"
|
||||||
"go/format"
|
"go/format"
|
||||||
"go/token"
|
"go/token"
|
||||||
|
"io"
|
||||||
"os"
|
"os"
|
||||||
|
"strings"
|
||||||
|
|
||||||
"golang.org/x/tools/go/ast/astutil"
|
"golang.org/x/tools/go/ast/astutil"
|
||||||
"golang.org/x/tools/go/packages"
|
"golang.org/x/tools/go/packages"
|
||||||
|
"golang.org/x/tools/imports"
|
||||||
)
|
)
|
||||||
|
|
||||||
func printDebugInfo() {
|
func printDebugInfo() {
|
||||||
@ -32,7 +35,12 @@ func main() {
|
|||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Parse package
|
genCogoFuncs(cwd)
|
||||||
|
genHasGenChecksOnOriginalFuncs(cwd)
|
||||||
|
}
|
||||||
|
|
||||||
|
func genCogoFuncs(cwd string) {
|
||||||
|
|
||||||
pkgs, err := packages.Load(&packages.Config{
|
pkgs, err := packages.Load(&packages.Config{
|
||||||
Dir: cwd,
|
Dir: cwd,
|
||||||
Mode: packages.NeedName | packages.NeedTypes | packages.NeedTypesInfo | packages.NeedSyntax,
|
Mode: packages.NeedName | packages.NeedTypes | packages.NeedTypesInfo | packages.NeedSyntax,
|
||||||
@ -42,46 +50,104 @@ func main() {
|
|||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(pkgs) != 1 {
|
// if len(pkgs) != 1 {
|
||||||
panic(fmt.Sprintf("expected to find one package but found %d", len(pkgs)))
|
// panic(fmt.Sprintf("expected to find one package but found %d", len(pkgs)))
|
||||||
}
|
// }
|
||||||
|
|
||||||
processPkg(pkgs[0])
|
for _, pkg := range pkgs {
|
||||||
|
|
||||||
|
p := &processor{
|
||||||
|
fset: pkg.Fset,
|
||||||
|
funcDeclsToWrite: []*ast.FuncDecl{},
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, synFile := range pkg.Syntax {
|
||||||
|
|
||||||
|
pkg.Syntax[i] = astutil.Apply(synFile, p.genCogoFuncsNodeProcessor, nil).(*ast.File)
|
||||||
|
|
||||||
|
if len(p.funcDeclsToWrite) > 0 {
|
||||||
|
|
||||||
|
root := &ast.File{
|
||||||
|
Name: synFile.Name,
|
||||||
|
Imports: synFile.Imports,
|
||||||
|
Decls: []ast.Decl{},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, v := range p.funcDeclsToWrite {
|
||||||
|
root.Decls = append(root.Decls, &ast.FuncDecl{
|
||||||
|
Name: ast.NewIdent(v.Name.Name + "_cogo"),
|
||||||
|
Type: v.Type,
|
||||||
|
Body: v.Body,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
origFName := pkg.Fset.File(synFile.Pos()).Name()
|
||||||
|
newFName := strings.Split(origFName, ".")[0] + ".cogo.go"
|
||||||
|
writeAst(newFName, pkg.Fset, root)
|
||||||
|
|
||||||
|
p.funcDeclsToWrite = p.funcDeclsToWrite[:0]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func processPkg(pkg *packages.Package) {
|
func writeAst(fName string, fset *token.FileSet, node any) {
|
||||||
|
|
||||||
p := processor{
|
f, err := os.Create(fName)
|
||||||
fset: pkg.Fset,
|
if err != nil {
|
||||||
funcDeclsToWrite: []*ast.FuncDecl{},
|
panic("Failed to create file to write new AST. Err: " + err.Error())
|
||||||
}
|
}
|
||||||
for i, synFile := range pkg.Syntax {
|
defer f.Close()
|
||||||
|
|
||||||
pkg.Syntax[i] = astutil.Apply(synFile, p.processDeclNode, nil).(*ast.File)
|
err = format.Node(f, fset, node)
|
||||||
|
if err != nil {
|
||||||
|
panic(err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
if len(p.funcDeclsToWrite) > 0 {
|
b, err := imports.Process(fName, nil, nil)
|
||||||
|
if err != nil {
|
||||||
|
panic("Failed to process imports on file " + fName + ". Err: " + err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
root := &ast.File{
|
f.Seek(0, io.SeekStart)
|
||||||
Name: synFile.Name,
|
_, err = f.Write(b)
|
||||||
Imports: synFile.Imports,
|
if err != nil {
|
||||||
Decls: []ast.Decl{},
|
panic(err.Error())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func genHasGenChecksOnOriginalFuncs(cwd string) {
|
||||||
|
|
||||||
|
pkgs, err := packages.Load(&packages.Config{
|
||||||
|
Dir: cwd,
|
||||||
|
Mode: packages.NeedName | packages.NeedTypes | packages.NeedTypesInfo | packages.NeedSyntax,
|
||||||
|
Tests: false,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// if len(pkgs) != 1 {
|
||||||
|
// panic(fmt.Sprintf("expected to find one package but found %d", len(pkgs)))
|
||||||
|
// }
|
||||||
|
|
||||||
|
for _, pkg := range pkgs {
|
||||||
|
|
||||||
|
p := &processor{
|
||||||
|
fset: pkg.Fset,
|
||||||
|
funcDeclsToWrite: []*ast.FuncDecl{},
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, synFile := range pkg.Syntax {
|
||||||
|
|
||||||
|
pkg.Syntax[i] = astutil.Apply(synFile, p.genHasGenChecksOnOriginalFuncsNodeProcessor, nil).(*ast.File)
|
||||||
|
|
||||||
|
if len(p.funcDeclsToWrite) > 0 {
|
||||||
|
origFName := pkg.Fset.File(synFile.Pos()).Name()
|
||||||
|
writeAst(origFName, pkg.Fset, pkg.Syntax[i])
|
||||||
|
|
||||||
|
p.funcDeclsToWrite = p.funcDeclsToWrite[:0]
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, v := range p.funcDeclsToWrite {
|
|
||||||
root.Decls = append(root.Decls, &ast.FuncDecl{
|
|
||||||
Name: ast.NewIdent(v.Name.Name + "_cogo"),
|
|
||||||
Type: v.Type,
|
|
||||||
Body: v.Body,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// imports.Process()
|
|
||||||
err := format.Node(os.Stdout, pkg.Fset, root)
|
|
||||||
if err != nil {
|
|
||||||
panic(err.Error())
|
|
||||||
}
|
|
||||||
|
|
||||||
p.funcDeclsToWrite = p.funcDeclsToWrite[:0]
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -92,7 +158,7 @@ type processor struct {
|
|||||||
funcDeclsToWrite []*ast.FuncDecl
|
funcDeclsToWrite []*ast.FuncDecl
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *processor) processDeclNode(c *astutil.Cursor) bool {
|
func (p *processor) genCogoFuncsNodeProcessor(c *astutil.Cursor) bool {
|
||||||
|
|
||||||
n := c.Node()
|
n := c.Node()
|
||||||
if n == nil {
|
if n == nil {
|
||||||
@ -226,6 +292,45 @@ func (p *processor) processDeclNode(c *astutil.Cursor) bool {
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (p *processor) genHasGenChecksOnOriginalFuncsNodeProcessor(c *astutil.Cursor) bool {
|
||||||
|
|
||||||
|
n := c.Node()
|
||||||
|
if n == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
funcDecl, ok := n.(*ast.FuncDecl)
|
||||||
|
if !ok || funcDecl.Body == nil {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
coroutineParamName := getCoroutineParamNameFromFuncDecl(funcDecl)
|
||||||
|
if coroutineParamName == "" || !funcDeclCallsCoroutineBegin(funcDecl, coroutineParamName) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, stmt := range funcDecl.Body.List {
|
||||||
|
|
||||||
|
// If one already exists update it and return
|
||||||
|
ifStmt, ok := stmt.(*ast.IfStmt)
|
||||||
|
if ok && ifStmtIsHasGen(ifStmt) {
|
||||||
|
funcDecl.Body.List[i] = createHasGenIfStmt(funcDecl, coroutineParamName)
|
||||||
|
p.funcDeclsToWrite = append(p.funcDeclsToWrite, funcDecl)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// If the check doesn't exist add it to the beginning of the function
|
||||||
|
origList := funcDecl.Body.List
|
||||||
|
funcDecl.Body.List = make([]ast.Stmt, 0, len(origList)+1)
|
||||||
|
funcDecl.Body.List = append(funcDecl.Body.List, createHasGenIfStmt(funcDecl, coroutineParamName))
|
||||||
|
funcDecl.Body.List = append(funcDecl.Body.List, origList...)
|
||||||
|
|
||||||
|
p.funcDeclsToWrite = append(p.funcDeclsToWrite, funcDecl)
|
||||||
|
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
func createHasGenIfStmt(funcDecl *ast.FuncDecl, coroutineParamName string) *ast.IfStmt {
|
func createHasGenIfStmt(funcDecl *ast.FuncDecl, coroutineParamName string) *ast.IfStmt {
|
||||||
return &ast.IfStmt{
|
return &ast.IfStmt{
|
||||||
Cond: createStmtFromSelFuncCall("cogo", "HasGen").(*ast.ExprStmt).X,
|
Cond: createStmtFromSelFuncCall("cogo", "HasGen").(*ast.ExprStmt).X,
|
||||||
|
|||||||
6
main.go
6
main.go
@ -9,14 +9,8 @@ import (
|
|||||||
"github.com/bloeys/cogo/cogo"
|
"github.com/bloeys/cogo/cogo"
|
||||||
)
|
)
|
||||||
|
|
||||||
func test_cogo(c *cogo.Coroutine[int, int]) (out int) { return 0 }
|
|
||||||
|
|
||||||
func test(c *cogo.Coroutine[int, int]) (out int) {
|
func test(c *cogo.Coroutine[int, int]) (out int) {
|
||||||
|
|
||||||
if cogo.HasGen() {
|
|
||||||
return test_cogo(c)
|
|
||||||
}
|
|
||||||
|
|
||||||
c.Begin()
|
c.Begin()
|
||||||
|
|
||||||
println("Tick 1")
|
println("Tick 1")
|
||||||
|
|||||||
Reference in New Issue
Block a user