From e0cc0c3de443cb9f7b38186dd582449ce7759ef9 Mon Sep 17 00:00:00 2001 From: bloeys Date: Fri, 28 Oct 2022 01:04:47 +0400 Subject: [PATCH] Support one level nesting of yield in ifStmts --- cogo/cogo.go | 10 ++- inliner/main.go | 160 ++++++++++++++++++++++++++++++++++++++++++++++-- main.cogo.go | 44 +++++++++++-- main.go | 16 +++-- 4 files changed, 212 insertions(+), 18 deletions(-) diff --git a/cogo/cogo.go b/cogo/cogo.go index 2d5eebb..037de8d 100755 --- a/cogo/cogo.go +++ b/cogo/cogo.go @@ -1,11 +1,14 @@ package cogo +import "fmt" + type CoroutineFunc[InT, OutT any] func(c *Coroutine[InT, OutT]) (out OutT) type Coroutine[InT, OutT any] struct { - State int32 - In InT - Func CoroutineFunc[InT, OutT] + State int32 + SubState int32 + In InT + Func CoroutineFunc[InT, OutT] } func (c *Coroutine[InT, OutT]) Begin() { @@ -22,6 +25,7 @@ func (c *Coroutine[InT, OutT]) Tick() (out OutT, done bool) { } func (c *Coroutine[InT, OutT]) Yield(out OutT) { + panic(fmt.Sprintf("Yield got called at runtime, which means the code generator was not run, you used cogo incorrectly, or cogo has a bug. Yield should NOT get called at runtime. coroutine: %+v;;; yield value: %+v;;;", c, out)) } func HasGen() bool { diff --git a/inliner/main.go b/inliner/main.go index 9eb9d6b..f9a5274 100755 --- a/inliner/main.go +++ b/inliner/main.go @@ -6,6 +6,7 @@ import ( "go/format" "go/token" "io" + "math/rand" "os" "strings" @@ -166,13 +167,49 @@ func (p *processor) genCogoFuncsNodeProcessor(c *astutil.Cursor) bool { }, } + subSwitchStmt := &ast.SwitchStmt{ + Tag: ast.NewIdent(coroutineParamName + ".SubState"), + Body: &ast.BlockStmt{ + List: []ast.Stmt{ + getCaseWithStmts(nil, []ast.Stmt{}), + }, + }, + } + for i, stmt := range funcDecl.Body.List { var cogoFuncSelExpr *ast.SelectorExpr - ifStmt, ok := stmt.(*ast.IfStmt) - if ok && ifStmtIsHasGen(ifStmt) { - funcDecl.Body.List[i] = &ast.EmptyStmt{} + ifStmt, ifStmtOk := stmt.(*ast.IfStmt) + if ifStmtOk { + + if ifStmtIsHasGen(ifStmt) { + funcDecl.Body.List[i] = &ast.EmptyStmt{} + continue + } + + subStateNums := p.genCogoIfStmt(ifStmt, coroutineParamName, len(switchStmt.Body.List)) + for _, subStateNum := range subStateNums { + + subSwitchStmt.Body.List = append(subSwitchStmt.Body.List, + getCaseWithStmts( + []ast.Expr{ast.NewIdent(fmt.Sprint(subStateNum))}, + []ast.Stmt{ + &ast.BranchStmt{ + Tok: token.GOTO, + Label: ast.NewIdent(getLblNameFromSubStateNum(subStateNum)), + }, + }, + ), + ) + + var stmtToInsert ast.Stmt = &ast.LabeledStmt{ + Label: ast.NewIdent(getLblNameFromSubStateNum(subStateNum)), + Stmt: &ast.EmptyStmt{}, + } + funcDecl.Body.List = insertIntoArr(funcDecl.Body.List, i+1, stmtToInsert) + } + continue } @@ -210,12 +247,28 @@ func (p *processor) genCogoFuncsNodeProcessor(c *astutil.Cursor) bool { stmtsSinceLastCogo := funcDecl.Body.List[lastCaseEndBodyListIndex+1 : i] caseStmts := make([]ast.Stmt, 0, len(stmtsSinceLastCogo)+2) + + caseStmts = append(caseStmts, subSwitchStmt) + subSwitchStmt = &ast.SwitchStmt{ + Tag: ast.NewIdent(coroutineParamName + ".SubState"), + Body: &ast.BlockStmt{ + List: []ast.Stmt{ + getCaseWithStmts(nil, []ast.Stmt{}), + }, + }, + } + caseStmts = append(caseStmts, stmtsSinceLastCogo...) caseStmts = append(caseStmts, &ast.IncDecStmt{ Tok: token.INC, X: ast.NewIdent(coroutineParamName + ".State"), }, + &ast.AssignStmt{ + Lhs: []ast.Expr{ast.NewIdent(coroutineParamName + ".SubState")}, + Tok: token.ASSIGN, + Rhs: []ast.Expr{ast.NewIdent("-1")}, + }, &ast.ReturnStmt{ Results: callExpr.Args, }, @@ -237,12 +290,28 @@ func (p *processor) genCogoFuncsNodeProcessor(c *astutil.Cursor) bool { stmtsToEndOfFunc := funcDecl.Body.List[lastCaseEndBodyListIndex+1:] caseStmts := make([]ast.Stmt, 0, len(stmtsToEndOfFunc)+1) + + caseStmts = append(caseStmts, subSwitchStmt) + subSwitchStmt = &ast.SwitchStmt{ + Tag: ast.NewIdent(coroutineParamName + ".SubState"), + Body: &ast.BlockStmt{ + List: []ast.Stmt{ + getCaseWithStmts(nil, []ast.Stmt{}), + }, + }, + } + caseStmts = append(caseStmts, &ast.AssignStmt{ Lhs: []ast.Expr{ast.NewIdent(coroutineParamName + ".State")}, Tok: token.ASSIGN, Rhs: []ast.Expr{ast.NewIdent("-1")}, }, + &ast.AssignStmt{ + Lhs: []ast.Expr{ast.NewIdent(coroutineParamName + ".SubState")}, + Tok: token.ASSIGN, + Rhs: []ast.Expr{ast.NewIdent("-1")}, + }, ) caseStmts = append(caseStmts, stmtsToEndOfFunc...) @@ -261,6 +330,11 @@ func (p *processor) genCogoFuncsNodeProcessor(c *astutil.Cursor) bool { Tok: token.ASSIGN, Rhs: []ast.Expr{ast.NewIdent("-1")}, }, + &ast.AssignStmt{ + Lhs: []ast.Expr{ast.NewIdent(coroutineParamName + ".SubState")}, + Tok: token.ASSIGN, + Rhs: []ast.Expr{ast.NewIdent("-1")}, + }, &ast.ReturnStmt{}, }, ), @@ -282,6 +356,81 @@ func (p *processor) genCogoFuncsNodeProcessor(c *astutil.Cursor) bool { return true } +func (p *processor) genCogoIfStmt(ifStmt *ast.IfStmt, coroutineParamName string, currCase int) (subStateNums []int32) { + + for i, stmt := range ifStmt.Body.List { + + selExpr, selExprArgs := tryGetSelExprFromStmt(stmt, coroutineParamName, "Yield") + if selExpr == nil { + continue + } + + // @TODO: Ensure that subStateNums don't get reused + // subStateNum >= 1000_000 + newSubStateNum := rand.Int31() + 1000_000 + subStateNums = append(subStateNums, newSubStateNum) + ifStmt.Body.List[i] = &ast.BlockStmt{ + List: []ast.Stmt{ + &ast.AssignStmt{ + Lhs: []ast.Expr{ast.NewIdent(coroutineParamName + ".State")}, + Tok: token.ASSIGN, + Rhs: []ast.Expr{ast.NewIdent(fmt.Sprint(currCase))}, + }, + &ast.AssignStmt{ + Lhs: []ast.Expr{ast.NewIdent(coroutineParamName + ".SubState")}, + Tok: token.ASSIGN, + Rhs: []ast.Expr{ast.NewIdent(fmt.Sprint(newSubStateNum))}, + }, + &ast.ReturnStmt{ + Results: selExprArgs, + }, + }, + } + } + + return subStateNums +} + +func getLblNameFromSubStateNum(subStateNum int32) string { + return fmt.Sprint("cogo_", subStateNum) +} + +func insertIntoArr[T any](a []T, index int, value T) []T { + + if len(a) == index { + return append(a, value) + } + + a = append(a[:index+1], a[index:]...) + a[index] = value + return a +} + +func tryGetSelExprFromStmt(stmt ast.Stmt, lhs, rhs string) (selExpr *ast.SelectorExpr, args []ast.Expr) { + + exprStmt, ok := stmt.(*ast.ExprStmt) + if !ok { + return nil, nil + } + + callExpr, ok := exprStmt.X.(*ast.CallExpr) + if !ok { + return nil, nil + } + args = callExpr.Args + + selExpr, ok = callExpr.Fun.(*ast.SelectorExpr) + if !ok { + return nil, nil + } + + if selExprIs(selExpr, lhs, rhs) { + return selExpr, args + } + + return nil, nil +} + func (p *processor) genHasGenChecksOnOriginalFuncsNodeProcessor(c *astutil.Cursor) bool { n := c.Node() @@ -434,14 +583,14 @@ func funcCallHasLhsName(selExpr *ast.SelectorExpr, pkgName string) bool { return ok && pkgIdent.Name == pkgName } -func selExprIs(selExpr *ast.SelectorExpr, pkgName, typeName string) bool { +func selExprIs(selExpr *ast.SelectorExpr, lhs, rhs string) bool { pkgIdentExpr, ok := selExpr.X.(*ast.Ident) if !ok { return false } - return pkgIdentExpr.Name == pkgName && selExpr.Sel.Name == typeName + return pkgIdentExpr.Name == lhs && selExpr.Sel.Name == rhs } // func filter[T any](arr []T, where func(x T) bool) []T { @@ -485,6 +634,7 @@ func writeAst(fName, topComment string, fset *token.FileSet, node any) { b, err := imports.Process(fName, nil, nil) if err != nil { + format.Node(os.Stdout, fset, node) panic("Failed to process imports on file " + fName + ". Err: " + err.Error()) } diff --git a/main.cogo.go b/main.cogo.go index c484d9b..c282fc8 100755 --- a/main.cogo.go +++ b/main.cogo.go @@ -6,33 +6,67 @@ import "github.com/bloeys/cogo/cogo" func test_cogo(c *cogo.Coroutine[int, int]) (out int) { switch c.State { case 0: + switch c.SubState { + default: + } - println("Tick 1") + println("\nTick 1") c.State++ + c.SubState = -1 return 1 case 1: + switch c.SubState { + default: + case 1299498081: + goto cogo_1299498081 + } - println("Tick 2") + if c.In > 1 { + println("\nTick 1.5") + { + c.State = 1 + c.SubState = 1299498081 + return c.In + } + } + cogo_1299498081: + ; + + println("\nTick 2") c.State++ + c.SubState = -1 return 2 case 2: + switch c.SubState { + default: + } - println("Tick 3") + println("\nTick 3") c.State++ + c.SubState = -1 return 3 case 3: + switch c.SubState { + default: + } - println("Tick 4") + println("\nTick 4") c.State++ + c.SubState = -1 return 4 case 4: + switch c.SubState { + default: + } c.State = -1 + c.SubState = -1 - println("Tick before end") + println("\nTick before end") return out default: c.State = -1 + c.SubState = -1 return } } diff --git a/main.go b/main.go index c873901..a781b75 100755 --- a/main.go +++ b/main.go @@ -16,19 +16,24 @@ func test(c *cogo.Coroutine[int, int]) (out int) { c.Begin() - println("Tick 1") + println("\nTick 1") c.Yield(1) - println("Tick 2") + if c.In > 1 { + println("\nTick 1.5") + c.Yield(c.In) + } + + println("\nTick 2") c.Yield(2) - println("Tick 3") + println("\nTick 3") c.Yield(3) - println("Tick 4") + println("\nTick 4") c.Yield(4) - println("Tick before end") + println("\nTick before end") return out } @@ -40,6 +45,7 @@ func main() { In: 0, } + c.In = 5 for out, done := c.Tick(); !done; out, done = c.Tick() { println(out) }