diff --git a/demo.cogo.go b/demo.cogo.go index 6aed80c..5b1ce74 100755 --- a/demo.cogo.go +++ b/demo.cogo.go @@ -9,97 +9,46 @@ import ( func test_cogo(c *cogo.Coroutine[int, int]) { switch c.State { - case 0: - switch c.SubState { - default: - } + case 1: + goto cogo_1_1 + case 2: + c.State = 1 + goto cogo_2_0 + case 3: + goto cogo_1_3 + } - println("test yield:", 1) - c.State++ - c.SubState = -1 + println("test yield:", 1) + { + c.State = 1 c.Out = 1 return - case 1: - switch c.SubState { - default: - } - c.State++ - c.SubState = -1 - c.Yielder = cogo.NewSleeper(100 * time.Millisecond) - return - case 2: - switch c.SubState { - default: - } - c.State++ - c.SubState = -1 - c.Yielder = cogo.New(test2, 0) - return - case 3: - switch c.SubState { - default: - } + } +cogo_1_1: + ; +cogo_2_0: + ; - println("test yield:", 2) - c.State++ - c.SubState = -1 + if c.Out > 2 { + switch c.State { + case 1: + goto cogo_2_1 + } + { + c.State = 2 + c.Out = 1 + return + } + cogo_2_1: + } + + c.YieldTo(cogo.NewSleeper(100 * time.Millisecond)) + + println("test yield:", 2) + { + c.State = 3 c.Out = 2 return - case 4: - switch c.SubState { - default: - } - c.State = -1 - c.SubState = -1 - default: - c.State = -1 - c.SubState = -1 - return - } -} - -func test2_cogo(c *cogo.Coroutine[int, int]) { - switch c.State { - case 0: - switch c.SubState { - default: - } - - println("test2222 yield:", 1) - c.State++ - c.SubState = -1 - c.Out = 1 - return - case 1: - switch c.SubState { - default: - } - - println("test2222 yield:", 2) - c.State++ - c.SubState = -1 - c.Out = 2 - return - case 2: - switch c.SubState { - default: - } - - println("test2222 before yield none") - c.State++ - c.SubState = -1 - return - case 3: - switch c.SubState { - default: - } - c.State = -1 - c.SubState = -1 - - println("test2222 after yield none") - default: - c.State = -1 - c.SubState = -1 - return } +cogo_1_3: } diff --git a/demo.go b/demo.go index 7896535..2938fe5 100755 --- a/demo.go +++ b/demo.go @@ -23,41 +23,73 @@ func runDemo() { } func test(c *cogo.Coroutine[int, int]) { - if cogo.HasGen() { - test_cogo(c) - return - } - - c.Begin() println("test yield:", 1) c.Yield(1) + if c.Out > 2 { + c.Yield(1) + } + // Yield here until at least 100ms passed c.YieldTo(cogo.NewSleeper(100 * time.Millisecond)) // Yield here until the coroutine 'test2' has finished - c.YieldTo(cogo.New(test2, 0)) + // c.YieldTo(cogo.New(test2, 0)) println("test yield:", 2) c.Yield(2) } -func test2(c *cogo.Coroutine[int, int]) { - if cogo.HasGen() { - test2_cogo(c) - return +// func test2(c *cogo.Coroutine[int, int]) { + +// println("test2222 yield:", 1) +// c.Yield(1) + +// println("test2222 yield:", 2) +// c.Yield(2) + +// println("test2222 before yield none") +// c.YieldNone() +// println("test2222 after yield none") +// } + +func NewApproach(state int) { + + switch state { + case 1: + goto lbl_1 + case 2: + goto lbl_2 + case 3: + state = 1 + goto lbl_2 + default: } - c.Begin() + println("1") + println("2") + state = 1 + // return - println("test2222 yield:", 1) - c.Yield(1) +lbl_1: + println("3") + state = 2 + // return - println("test2222 yield:", 2) - c.Yield(2) +lbl_2: + { + switch state { + case 1: + goto lbl_3 - println("test2222 before yield none") - c.YieldNone() - println("test2222 after yield none") + default: + } + + println("4") + state = 3 + // return + + lbl_3: + } } diff --git a/main.go b/main.go index 8a46cd0..c66ce82 100755 --- a/main.go +++ b/main.go @@ -35,7 +35,7 @@ func main() { } genCogoFuncs(cwd) - genHasGenChecksOnOriginalFuncs(cwd) + // genHasGenChecksOnOriginalFuncs(cwd) } func genCogoFuncs(cwd string) { @@ -54,10 +54,12 @@ func genCogoFuncs(cwd string) { p := &processor{ fset: pkg.Fset, funcDeclsToWrite: []*ast.FuncDecl{}, + BlockInfos: make([]BlockInfo, 0, 10), } for i, synFile := range pkg.Syntax { - pkg.Syntax[i] = astutil.Apply(synFile, p.genCogoFuncsNodeProcessor, nil).(*ast.File) + pkg.Syntax[i] = astutil.Apply(synFile, p.nodeProcessor, nil).(*ast.File) + // pkg.Syntax[i] = astutil.Apply(synFile, p.genCogoFuncsNodeProcessor, nil).(*ast.File) if len(p.funcDeclsToWrite) > 0 { @@ -101,6 +103,7 @@ func genHasGenChecksOnOriginalFuncs(cwd string) { p := &processor{ fset: pkg.Fset, funcDeclsToWrite: []*ast.FuncDecl{}, + BlockInfos: make([]BlockInfo, 0, 10), } for i, synFile := range pkg.Syntax { @@ -118,9 +121,201 @@ func genHasGenChecksOnOriginalFuncs(cwd string) { } +type BlockInfo struct { + Switch *ast.SwitchStmt + BeforeBlockLblName string +} + +func (s *BlockInfo) addCase(caseStmts []ast.Stmt) int { + + oldCaseCount := s.CaseCount() + newCase := getCaseWithStmts([]ast.Expr{ + ast.NewIdent(fmt.Sprint(oldCaseCount + 1))}, + caseStmts, + ) + + s.Switch.Body.List = append(s.Switch.Body.List, newCase) + return len(s.Switch.Body.List) +} + +func (s *BlockInfo) CaseCount() int { + return len(s.Switch.Body.List) +} + type processor struct { fset *token.FileSet funcDeclsToWrite []*ast.FuncDecl + BlockInfos []BlockInfo +} + +func (p *processor) nodeProcessor(c *astutil.Cursor) bool { + + n := c.Node() + if n == nil { + return false + } + + // If not a function or it's empty skip and continue reading the file AST + funcDecl, ok := n.(*ast.FuncDecl) + if !ok || funcDecl.Body == nil || len(funcDecl.Body.List) == 0 { + return true + } + + // Check if function has the required params + coroutineParamName := getCoroutineParamNameFromFuncDecl(funcDecl) + if coroutineParamName == "" { + return false + } + + if !blockUsesCogo(funcDecl.Body, coroutineParamName, true) { + return false + } + + // Generate code for function + p.processBlock(nil, funcDecl.Body, -1, coroutineParamName) + p.funcDeclsToWrite = append(p.funcDeclsToWrite, funcDecl) + return false +} + +func (p *processor) processBlock(parentBlock, blockStmt *ast.BlockStmt, indexInParent int, coroutineParamName string) { + + if !blockUsesCogo(blockStmt, coroutineParamName, true) { + return + } + + p.pushNewBlock(parentBlock, blockStmt, indexInParent, &ast.SwitchStmt{ + Tag: ast.NewIdent(coroutineParamName + ".State"), + Body: &ast.BlockStmt{ + List: []ast.Stmt{}, + }, + }) + defer p.popSwitch() + + for i := 0; i < len(blockStmt.List); i++ { + + stmt := blockStmt.List[i] + + if ifStmt, ifStmtOk := stmt.(*ast.IfStmt); ifStmtOk { + + p.processBlock(blockStmt, ifStmt.Body, i, coroutineParamName) + + } else if forStmt, forStmtOk := stmt.(*ast.ForStmt); forStmtOk { + + // @TODO: For loops need unique handling to convert to an if statement (if they use cogo) + p.processBlock(blockStmt, forStmt.Body, i, coroutineParamName) + + } else if bStmt, blockStmtOk := stmt.(*ast.BlockStmt); blockStmtOk { + + p.processBlock(blockStmt, bStmt, i, coroutineParamName) + } + + selExpr, selExprArgs := tryGetSelExprFromStmt(stmt, coroutineParamName, "Yield") + if selExpr == nil { + continue + } + + p.addYield(blockStmt, &i, selExprArgs, coroutineParamName) + } + +} + +func (p *processor) pushNewBlock(parentBlock, blockStmt *ast.BlockStmt, indexInParent int, switchStmt *ast.SwitchStmt) { + + beforeBlockLblName := getLblNameFromSubStateNum(int32(len(p.BlockInfos)+1), 0) + if parentBlock != nil { + parentBlock.List = insertIntoArr[ast.Stmt](parentBlock.List, indexInParent, &ast.LabeledStmt{ + Label: ast.NewIdent(beforeBlockLblName), + Stmt: &ast.EmptyStmt{}, + }) + } + + // Add switch to start of block + blockStmt.List = insertIntoArr[ast.Stmt](blockStmt.List, 0, switchStmt) + + p.BlockInfos = append(p.BlockInfos, BlockInfo{ + Switch: switchStmt, + BeforeBlockLblName: beforeBlockLblName, + }) +} + +func (p *processor) currSwitch() *ast.SwitchStmt { + return p.BlockInfos[len(p.BlockInfos)-1].Switch +} + +func (p *processor) currSwitchInfo() *BlockInfo { + return &p.BlockInfos[len(p.BlockInfos)-1] +} + +func (p *processor) popSwitch() *ast.SwitchStmt { + + s := p.BlockInfos[len(p.BlockInfos)-1] + p.BlockInfos = p.BlockInfos[:len(p.BlockInfos)-1] + return s.Switch +} + +func toStr[T any](x T) string { + return fmt.Sprintf("%+v", x) +} + +func (p *processor) addYield(block *ast.BlockStmt, listIndex *int, yieldArgs []ast.Expr, coroutineParamName string) { + + // Update switch statements and find the new state value + newCaseCondition := p.currSwitchInfo().CaseCount() + 1 + newLblName := getLblNameFromSubStateNum(int32(len(p.BlockInfos)), int32(newCaseCondition)) + + // Go over all switch cases from current inner block to outer most + // and add a new case in each to make us reach this new yield + p.currSwitchInfo().addCase([]ast.Stmt{ + &ast.BranchStmt{ + Tok: token.GOTO, + Label: ast.NewIdent(newLblName), + }, + }) + + stateValNeededForNexSwitch := newCaseCondition + prevBlockBeforeLblName := p.currSwitchInfo().BeforeBlockLblName + for i := len(p.BlockInfos) - 2; i >= 0; i-- { + + info := &p.BlockInfos[i] + info.addCase([]ast.Stmt{ + &ast.AssignStmt{ + Lhs: []ast.Expr{ast.NewIdent(coroutineParamName + ".State")}, + Tok: token.ASSIGN, + Rhs: []ast.Expr{ast.NewIdent(toStr(stateValNeededForNexSwitch))}, + }, + &ast.BranchStmt{ + Tok: token.GOTO, + Label: ast.NewIdent(prevBlockBeforeLblName), + }, + }) + + stateValNeededForNexSwitch = info.CaseCount() + prevBlockBeforeLblName = info.BeforeBlockLblName + } + + // Create and add yield block + newBlock := &ast.BlockStmt{ + List: []ast.Stmt{ + &ast.AssignStmt{ + Lhs: []ast.Expr{ast.NewIdent(coroutineParamName + ".State")}, + Tok: token.ASSIGN, + Rhs: []ast.Expr{ast.NewIdent(toStr(stateValNeededForNexSwitch))}, + }, + &ast.AssignStmt{ + Lhs: []ast.Expr{ast.NewIdent(coroutineParamName + ".Out")}, + Tok: token.ASSIGN, + Rhs: yieldArgs, + }, + &ast.ReturnStmt{}, + }, + } + + block.List[*listIndex] = newBlock + block.List = insertIntoArr[ast.Stmt](block.List, *listIndex+1, &ast.LabeledStmt{ + Label: ast.NewIdent(newLblName), + Stmt: &ast.EmptyStmt{}, + }) + *listIndex++ } func (p *processor) genCogoFuncsNodeProcessor(c *astutil.Cursor) bool { @@ -140,8 +335,8 @@ func (p *processor) genCogoFuncsNodeProcessor(c *astutil.Cursor) bool { return false } - hasBegin := funcHasSelInBody(funcDecl, coroutineParamName, "Begin") - hasYield := funcHasSelInBody(funcDecl, coroutineParamName, "Yield") + hasBegin := blockHasOneOrMoreSels(funcDecl.Body, []SelExprInfo{{coroutineParamName, "Begin"}}, false) + hasYield := blockHasOneOrMoreSels(funcDecl.Body, []SelExprInfo{{coroutineParamName, "Yield"}}, false) if !hasBegin && !hasYield { return false } @@ -152,7 +347,7 @@ func (p *processor) genCogoFuncsNodeProcessor(c *astutil.Cursor) bool { beginBodyListIndex := -1 lastCaseEndBodyListIndex := -1 - switchStmt := &ast.SwitchStmt{ + mainSwitchStmt := &ast.SwitchStmt{ Tag: ast.NewIdent(coroutineParamName + ".State"), Body: &ast.BlockStmt{ List: []ast.Stmt{}, @@ -172,8 +367,8 @@ func (p *processor) genCogoFuncsNodeProcessor(c *astutil.Cursor) bool { stmt := funcDecl.Body.List[i] - var cogoFuncSelExpr *ast.SelectorExpr var blockStmt *ast.BlockStmt + var selExpr *ast.SelectorExpr if ifStmt, ifStmtOk := stmt.(*ast.IfStmt); ifStmtOk { @@ -186,7 +381,7 @@ func (p *processor) genCogoFuncsNodeProcessor(c *astutil.Cursor) bool { } else if forStmt, forStmtOk := stmt.(*ast.ForStmt); forStmtOk { - outInitBlock, postInitStmt, outIfStmt, subStateNums := p.genCogoForStmt(forStmt, coroutineParamName, len(switchStmt.Body.List)) + outInitBlock, postInitStmt, outIfStmt, subStateNums := p.genCogoForStmt(forStmt, coroutineParamName, len(mainSwitchStmt.Body.List)) if len(subStateNums) > 1 { panic("For loops currently don't support more than one yield") @@ -202,7 +397,7 @@ func (p *processor) genCogoFuncsNodeProcessor(c *astutil.Cursor) bool { postInitStmt, &ast.BranchStmt{ Tok: token.GOTO, - Label: ast.NewIdent(getLblNameFromSubStateNum(subStateNums[0])), + Label: ast.NewIdent(getLblNameFromSubStateNum(1, subStateNums[0])), }, }, ), @@ -213,7 +408,7 @@ func (p *processor) genCogoFuncsNodeProcessor(c *astutil.Cursor) bool { // // Insert lable after initial condition var stmtInterface ast.Stmt = &ast.LabeledStmt{ - Label: ast.NewIdent(getLblNameFromSubStateNum(subStateNums[0])), + Label: ast.NewIdent(getLblNameFromSubStateNum(1, subStateNums[0])), Stmt: &ast.EmptyStmt{}, } funcDecl.Body.List = insertIntoArr(funcDecl.Body.List, i+1, stmtInterface) @@ -231,7 +426,7 @@ func (p *processor) genCogoFuncsNodeProcessor(c *astutil.Cursor) bool { if blockStmt != nil { - subStateNums := p.genCogoBlockStmt(blockStmt, coroutineParamName, len(switchStmt.Body.List)) + subStateNums := p.genCogoBlockStmt(blockStmt, coroutineParamName, len(mainSwitchStmt.Body.List)) for _, subStateNum := range subStateNums { subSwitchStmt.Body.List = append(subSwitchStmt.Body.List, @@ -240,14 +435,14 @@ func (p *processor) genCogoFuncsNodeProcessor(c *astutil.Cursor) bool { []ast.Stmt{ &ast.BranchStmt{ Tok: token.GOTO, - Label: ast.NewIdent(getLblNameFromSubStateNum(subStateNum)), + Label: ast.NewIdent(getLblNameFromSubStateNum(1, subStateNum)), }, }, ), ) var stmtToInsert ast.Stmt = &ast.LabeledStmt{ - Label: ast.NewIdent(getLblNameFromSubStateNum(subStateNum)), + Label: ast.NewIdent(getLblNameFromSubStateNum(1, subStateNum)), Stmt: &ast.EmptyStmt{}, } funcDecl.Body.List = insertIntoArr(funcDecl.Body.List, i+1, stmtToInsert) @@ -257,33 +452,33 @@ func (p *processor) genCogoFuncsNodeProcessor(c *astutil.Cursor) bool { continue } - // Find functions calls in the style of 'xyz.ABC123()' - exprStmt, exprStmtOk := stmt.(*ast.ExprStmt) - if !exprStmtOk { + // Find cogo function call in the style of 'cogo.Xyz()' + exprStmt, ok := stmt.(*ast.ExprStmt) + if !ok { continue } - callExpr, exprStmtOk := exprStmt.X.(*ast.CallExpr) - if !exprStmtOk { + callExpr, ok := exprStmt.X.(*ast.CallExpr) + if !ok { continue } - cogoFuncSelExpr, exprStmtOk = callExpr.Fun.(*ast.SelectorExpr) - if !exprStmtOk { + selExpr, ok = callExpr.Fun.(*ast.SelectorExpr) + if !ok { continue } - if !selExprHasLhsName(cogoFuncSelExpr, coroutineParamName) { + if !selExprHasLhsName(selExpr, coroutineParamName) { continue } // Now that we found a call to cogo decide what to do - if cogoFuncSelExpr.Sel.Name == "Begin" { + if selExpr.Sel.Name == "Begin" { beginBodyListIndex = i lastCaseEndBodyListIndex = i continue - } else if cogoFuncSelExpr.Sel.Name == "Yield" { + } else if selExpr.Sel.Name == "Yield" { // Add everything from the last begin/yield until this yield into a case stmtsSinceLastCogo := funcDecl.Body.List[lastCaseEndBodyListIndex+1 : i] @@ -319,16 +514,16 @@ func (p *processor) genCogoFuncsNodeProcessor(c *astutil.Cursor) bool { &ast.ReturnStmt{}, ) - switchStmt.Body.List = append(switchStmt.Body.List, + mainSwitchStmt.Body.List = append(mainSwitchStmt.Body.List, getCaseWithStmts( - []ast.Expr{ast.NewIdent(fmt.Sprint(len(switchStmt.Body.List)))}, + []ast.Expr{ast.NewIdent(fmt.Sprint(len(mainSwitchStmt.Body.List)))}, caseStmts, ), ) lastCaseEndBodyListIndex = i - } else if cogoFuncSelExpr.Sel.Name == "YieldTo" { + } else if selExpr.Sel.Name == "YieldTo" { // Add everything from the last begin/yield until this yield into a case stmtsSinceLastCogo := funcDecl.Body.List[lastCaseEndBodyListIndex+1 : i] @@ -364,15 +559,15 @@ func (p *processor) genCogoFuncsNodeProcessor(c *astutil.Cursor) bool { &ast.ReturnStmt{}, ) - switchStmt.Body.List = append(switchStmt.Body.List, + mainSwitchStmt.Body.List = append(mainSwitchStmt.Body.List, getCaseWithStmts( - []ast.Expr{ast.NewIdent(fmt.Sprint(len(switchStmt.Body.List)))}, + []ast.Expr{ast.NewIdent(fmt.Sprint(len(mainSwitchStmt.Body.List)))}, caseStmts, ), ) lastCaseEndBodyListIndex = i - } else if cogoFuncSelExpr.Sel.Name == "YieldNone" { + } else if selExpr.Sel.Name == "YieldNone" { // Add everything from the last begin/yield until this yield into a case stmtsSinceLastCogo := funcDecl.Body.List[lastCaseEndBodyListIndex+1 : i] @@ -403,9 +598,9 @@ func (p *processor) genCogoFuncsNodeProcessor(c *astutil.Cursor) bool { &ast.ReturnStmt{}, ) - switchStmt.Body.List = append(switchStmt.Body.List, + mainSwitchStmt.Body.List = append(mainSwitchStmt.Body.List, getCaseWithStmts( - []ast.Expr{ast.NewIdent(fmt.Sprint(len(switchStmt.Body.List)))}, + []ast.Expr{ast.NewIdent(fmt.Sprint(len(mainSwitchStmt.Body.List)))}, caseStmts, ), ) @@ -443,9 +638,9 @@ func (p *processor) genCogoFuncsNodeProcessor(c *astutil.Cursor) bool { ) caseStmts = append(caseStmts, stmtsToEndOfFunc...) - switchStmt.Body.List = append(switchStmt.Body.List, + mainSwitchStmt.Body.List = append(mainSwitchStmt.Body.List, getCaseWithStmts( - []ast.Expr{ast.NewIdent(fmt.Sprint(len(switchStmt.Body.List)))}, + []ast.Expr{ast.NewIdent(fmt.Sprint(len(mainSwitchStmt.Body.List)))}, caseStmts, ), @@ -471,7 +666,7 @@ func (p *processor) genCogoFuncsNodeProcessor(c *astutil.Cursor) bool { // Apply changes funcDecl.Body.List = funcDecl.Body.List[:beginBodyListIndex] funcDecl.Body.List = append(funcDecl.Body.List, - switchStmt, + mainSwitchStmt, ) originalList := funcDecl.Body.List @@ -586,8 +781,8 @@ func (p *processor) genCogoBlockStmt(blockStmt *ast.BlockStmt, coroutineParamNam return subStateNums } -func getLblNameFromSubStateNum(subStateNum int32) string { - return fmt.Sprint("cogo_", subStateNum) +func getLblNameFromSubStateNum(switchNum, caseNum int32) string { + return fmt.Sprintf("cogo_%d_%d", switchNum, caseNum) } func insertIntoArr[T any](a []T, index int, value T) []T { @@ -639,7 +834,7 @@ func (p *processor) genHasGenChecksOnOriginalFuncsNodeProcessor(c *astutil.Curso } coroutineParamName := getCoroutineParamNameFromFuncDecl(funcDecl) - if coroutineParamName == "" || !funcHasSelInBody(funcDecl, coroutineParamName, "Begin") { + if coroutineParamName == "" || !blockHasOneOrMoreSels(funcDecl.Body, []SelExprInfo{{coroutineParamName, "Begin"}}, false) { return false } @@ -707,13 +902,33 @@ func createStmtFromSelFuncCall(lhs, rhs string) ast.Stmt { } } -func funcHasSelInBody(fd *ast.FuncDecl, selLhs, selRhs string) bool { +type SelExprInfo struct { + // Give `cogo.Yield()`, Lhs would be 'cogo' + Lhs string + // Give `cogo.Yield()`, Rhs would be 'Yield' + Rhs string +} - if fd.Body == nil || len(fd.Body.List) == 0 { +func blockUsesCogo(block *ast.BlockStmt, coroutineParamName string, checkChildBlocks bool) bool { + return blockHasOneOrMoreSels(block, []SelExprInfo{{coroutineParamName, "Yield"}, {coroutineParamName, "YieldTo"}, {coroutineParamName, "YieldNone"}}, checkChildBlocks) +} + +func blockHasOneOrMoreSels(block *ast.BlockStmt, sels []SelExprInfo, checkChildBlocks bool) bool { + + if block == nil || len(block.List) == 0 { return false } - for _, stmt := range fd.Body.List { + for _, stmt := range block.List { + + // Check recursively if requested + if checkChildBlocks { + if blockStmt, ok := stmt.(*ast.BlockStmt); ok { + if blockHasOneOrMoreSels(blockStmt, sels, true) { + return true + } + } + } // Find functions calls in the style of 'cogo.ABC123()' exprStmt, ok := stmt.(*ast.ExprStmt) @@ -731,8 +946,10 @@ func funcHasSelInBody(fd *ast.FuncDecl, selLhs, selRhs string) bool { continue } - if selExprIs(selExpr, selLhs, selRhs) { - return true + for _, v := range sels { + if selExprIs(selExpr, v.Lhs, v.Rhs) { + return true + } } } @@ -789,10 +1006,10 @@ func selExprIs(selExpr *ast.SelectorExpr, lhs, rhs string) bool { return pkgIdentExpr.Name == lhs && selExpr.Sel.Name == rhs } -func getCaseWithStmts(caseConditions []ast.Expr, stmts []ast.Stmt) *ast.CaseClause { +func getCaseWithStmts(caseConditions []ast.Expr, bodyStmts []ast.Stmt) *ast.CaseClause { return &ast.CaseClause{ List: caseConditions, - Body: stmts, + Body: bodyStmts, } }