Masha Allah it works!

This commit is contained in:
bloeys
2022-10-27 05:58:58 +04:00
parent d389e76b8d
commit 94090dc9ab
2 changed files with 56 additions and 49 deletions

View File

@ -83,7 +83,7 @@ func genCogoFuncs(cwd string) {
origFName := pkg.Fset.File(synFile.Pos()).Name()
newFName := strings.Split(origFName, ".")[0] + ".cogo.go"
writeAst(newFName, pkg.Fset, root)
writeAst(newFName, "// Code generated by 'cogo'; DO NOT EDIT.\n", pkg.Fset, root)
p.funcDeclsToWrite = p.funcDeclsToWrite[:0]
}
@ -91,31 +91,6 @@ func genCogoFuncs(cwd string) {
}
}
func writeAst(fName string, fset *token.FileSet, node any) {
f, err := os.Create(fName)
if err != nil {
panic("Failed to create file to write new AST. Err: " + err.Error())
}
defer f.Close()
err = format.Node(f, fset, node)
if err != nil {
panic(err.Error())
}
b, err := imports.Process(fName, nil, nil)
if err != nil {
panic("Failed to process imports on file " + fName + ". Err: " + err.Error())
}
f.Seek(0, io.SeekStart)
_, err = f.Write(b)
if err != nil {
panic(err.Error())
}
}
func genHasGenChecksOnOriginalFuncs(cwd string) {
pkgs, err := packages.Load(&packages.Config{
@ -144,7 +119,7 @@ func genHasGenChecksOnOriginalFuncs(cwd string) {
if len(p.funcDeclsToWrite) > 0 {
origFName := pkg.Fset.File(synFile.Pos()).Name()
writeAst(origFName, pkg.Fset, pkg.Syntax[i])
writeAst(origFName, "", pkg.Fset, pkg.Syntax[i])
p.funcDeclsToWrite = p.funcDeclsToWrite[:0]
}
@ -184,15 +159,13 @@ func (p *processor) genCogoFuncsNodeProcessor(c *astutil.Cursor) bool {
},
}
hasGenCheckExists := false
for i, stmt := range funcDecl.Body.List {
var cogoFuncSelExpr *ast.SelectorExpr
ifStmt, ok := stmt.(*ast.IfStmt)
if ok && ifStmtIsHasGen(ifStmt) {
funcDecl.Body.List[i] = createHasGenIfStmt(funcDecl, coroutineParamName)
hasGenCheckExists = true
funcDecl.Body.List[i] = &ast.EmptyStmt{}
continue
}
@ -271,6 +244,19 @@ func (p *processor) genCogoFuncsNodeProcessor(c *astutil.Cursor) bool {
[]ast.Expr{ast.NewIdent(fmt.Sprint(len(switchStmt.Body.List)))},
caseStmts,
),
// default case
getCaseWithStmts(
nil,
[]ast.Stmt{
&ast.AssignStmt{
Lhs: []ast.Expr{ast.NewIdent(coroutineParamName + ".State")},
Tok: token.ASSIGN,
Rhs: []ast.Expr{ast.NewIdent("-1")},
},
&ast.ReturnStmt{},
},
),
)
// Apply changes
@ -282,9 +268,6 @@ func (p *processor) genCogoFuncsNodeProcessor(c *astutil.Cursor) bool {
originalList := funcDecl.Body.List
funcDecl.Body.List = make([]ast.Stmt, 0, len(funcDecl.Body.List)+1)
if !hasGenCheckExists {
funcDecl.Body.List = append(funcDecl.Body.List, createHasGenIfStmt(funcDecl, coroutineParamName))
}
funcDecl.Body.List = append(funcDecl.Body.List, originalList...)
p.funcDeclsToWrite = append(p.funcDeclsToWrite, funcDecl)
@ -362,14 +345,6 @@ func ifStmtIsHasGen(stmt *ast.IfStmt) bool {
return selExprIs(selExpr, "cogo", "HasGen")
}
func createStmtFromFuncCall(funcName string) ast.Stmt {
return &ast.ExprStmt{
X: &ast.CallExpr{
Fun: ast.NewIdent(funcName),
},
}
}
func createStmtFromSelFuncCall(lhs, rhs string) ast.Stmt {
return &ast.ExprStmt{
X: &ast.CallExpr{
@ -483,3 +458,32 @@ func getCaseWithStmts(caseConditions []ast.Expr, stmts []ast.Stmt) *ast.CaseClau
Body: stmts,
}
}
func writeAst(fName, topComment string, fset *token.FileSet, node any) {
f, err := os.Create(fName)
if err != nil {
panic("Failed to create file to write new AST. Err: " + err.Error())
}
defer f.Close()
if topComment != "" {
f.WriteString(topComment)
}
err = format.Node(f, fset, node)
if err != nil {
panic(err.Error())
}
b, err := imports.Process(fName, nil, nil)
if err != nil {
panic("Failed to process imports on file " + fName + ". Err: " + err.Error())
}
f.Seek(0, io.SeekStart)
_, err = f.Write(b)
if err != nil {
panic(err.Error())
}
}

19
main.go
View File

@ -10,22 +10,25 @@ import (
)
func test(c *cogo.Coroutine[int, int]) (out int) {
if cogo.HasGen() {
return test_cogo(c)
}
c.Begin()
println("Tick 1")
// c.Yield(1)
c.Yield(1)
// println("Tick 2")
// c.Yield(2)
println("Tick 2")
c.Yield(2)
// println("Tick 3")
// c.Yield(3)
println("Tick 3")
c.Yield(3)
// println("Tick 4")
// c.Yield(4)
println("Tick 4")
c.Yield(4)
// println("Tick before end")
println("Tick before end")
return out
}