New much better algorithm

This commit is contained in:
bloeys
2023-01-20 06:58:22 +04:00
parent ca355b2e06
commit 3ca3cf1d32
3 changed files with 346 additions and 148 deletions

View File

@ -9,97 +9,46 @@ import (
func test_cogo(c *cogo.Coroutine[int, int]) { func test_cogo(c *cogo.Coroutine[int, int]) {
switch c.State { switch c.State {
case 0: case 1:
switch c.SubState { goto cogo_1_1
default: case 2:
} c.State = 1
goto cogo_2_0
case 3:
goto cogo_1_3
}
println("test yield:", 1) println("test yield:", 1)
c.State++ {
c.SubState = -1 c.State = 1
c.Out = 1 c.Out = 1
return return
case 1: }
switch c.SubState { cogo_1_1:
default: ;
} cogo_2_0:
c.State++ ;
c.SubState = -1
c.Yielder = cogo.NewSleeper(100 * time.Millisecond)
return
case 2:
switch c.SubState {
default:
}
c.State++
c.SubState = -1
c.Yielder = cogo.New(test2, 0)
return
case 3:
switch c.SubState {
default:
}
println("test yield:", 2) if c.Out > 2 {
c.State++ switch c.State {
c.SubState = -1 case 1:
goto cogo_2_1
}
{
c.State = 2
c.Out = 1
return
}
cogo_2_1:
}
c.YieldTo(cogo.NewSleeper(100 * time.Millisecond))
println("test yield:", 2)
{
c.State = 3
c.Out = 2 c.Out = 2
return return
case 4:
switch c.SubState {
default:
}
c.State = -1
c.SubState = -1
default:
c.State = -1
c.SubState = -1
return
}
}
func test2_cogo(c *cogo.Coroutine[int, int]) {
switch c.State {
case 0:
switch c.SubState {
default:
}
println("test2222 yield:", 1)
c.State++
c.SubState = -1
c.Out = 1
return
case 1:
switch c.SubState {
default:
}
println("test2222 yield:", 2)
c.State++
c.SubState = -1
c.Out = 2
return
case 2:
switch c.SubState {
default:
}
println("test2222 before yield none")
c.State++
c.SubState = -1
return
case 3:
switch c.SubState {
default:
}
c.State = -1
c.SubState = -1
println("test2222 after yield none")
default:
c.State = -1
c.SubState = -1
return
} }
cogo_1_3:
} }

70
demo.go
View File

@ -23,41 +23,73 @@ func runDemo() {
} }
func test(c *cogo.Coroutine[int, int]) { func test(c *cogo.Coroutine[int, int]) {
if cogo.HasGen() {
test_cogo(c)
return
}
c.Begin()
println("test yield:", 1) println("test yield:", 1)
c.Yield(1) c.Yield(1)
if c.Out > 2 {
c.Yield(1)
}
// Yield here until at least 100ms passed // Yield here until at least 100ms passed
c.YieldTo(cogo.NewSleeper(100 * time.Millisecond)) c.YieldTo(cogo.NewSleeper(100 * time.Millisecond))
// Yield here until the coroutine 'test2' has finished // Yield here until the coroutine 'test2' has finished
c.YieldTo(cogo.New(test2, 0)) // c.YieldTo(cogo.New(test2, 0))
println("test yield:", 2) println("test yield:", 2)
c.Yield(2) c.Yield(2)
} }
func test2(c *cogo.Coroutine[int, int]) { // func test2(c *cogo.Coroutine[int, int]) {
if cogo.HasGen() {
test2_cogo(c) // println("test2222 yield:", 1)
return // c.Yield(1)
// println("test2222 yield:", 2)
// c.Yield(2)
// println("test2222 before yield none")
// c.YieldNone()
// println("test2222 after yield none")
// }
func NewApproach(state int) {
switch state {
case 1:
goto lbl_1
case 2:
goto lbl_2
case 3:
state = 1
goto lbl_2
default:
} }
c.Begin() println("1")
println("2")
state = 1
// return
println("test2222 yield:", 1) lbl_1:
c.Yield(1) println("3")
state = 2
// return
println("test2222 yield:", 2) lbl_2:
c.Yield(2) {
switch state {
case 1:
goto lbl_3
println("test2222 before yield none") default:
c.YieldNone() }
println("test2222 after yield none")
println("4")
state = 3
// return
lbl_3:
}
} }

303
main.go
View File

@ -35,7 +35,7 @@ func main() {
} }
genCogoFuncs(cwd) genCogoFuncs(cwd)
genHasGenChecksOnOriginalFuncs(cwd) // genHasGenChecksOnOriginalFuncs(cwd)
} }
func genCogoFuncs(cwd string) { func genCogoFuncs(cwd string) {
@ -54,10 +54,12 @@ func genCogoFuncs(cwd string) {
p := &processor{ p := &processor{
fset: pkg.Fset, fset: pkg.Fset,
funcDeclsToWrite: []*ast.FuncDecl{}, funcDeclsToWrite: []*ast.FuncDecl{},
BlockInfos: make([]BlockInfo, 0, 10),
} }
for i, synFile := range pkg.Syntax { for i, synFile := range pkg.Syntax {
pkg.Syntax[i] = astutil.Apply(synFile, p.genCogoFuncsNodeProcessor, nil).(*ast.File) pkg.Syntax[i] = astutil.Apply(synFile, p.nodeProcessor, nil).(*ast.File)
// pkg.Syntax[i] = astutil.Apply(synFile, p.genCogoFuncsNodeProcessor, nil).(*ast.File)
if len(p.funcDeclsToWrite) > 0 { if len(p.funcDeclsToWrite) > 0 {
@ -101,6 +103,7 @@ func genHasGenChecksOnOriginalFuncs(cwd string) {
p := &processor{ p := &processor{
fset: pkg.Fset, fset: pkg.Fset,
funcDeclsToWrite: []*ast.FuncDecl{}, funcDeclsToWrite: []*ast.FuncDecl{},
BlockInfos: make([]BlockInfo, 0, 10),
} }
for i, synFile := range pkg.Syntax { for i, synFile := range pkg.Syntax {
@ -118,9 +121,201 @@ func genHasGenChecksOnOriginalFuncs(cwd string) {
} }
type BlockInfo struct {
Switch *ast.SwitchStmt
BeforeBlockLblName string
}
func (s *BlockInfo) addCase(caseStmts []ast.Stmt) int {
oldCaseCount := s.CaseCount()
newCase := getCaseWithStmts([]ast.Expr{
ast.NewIdent(fmt.Sprint(oldCaseCount + 1))},
caseStmts,
)
s.Switch.Body.List = append(s.Switch.Body.List, newCase)
return len(s.Switch.Body.List)
}
func (s *BlockInfo) CaseCount() int {
return len(s.Switch.Body.List)
}
type processor struct { type processor struct {
fset *token.FileSet fset *token.FileSet
funcDeclsToWrite []*ast.FuncDecl funcDeclsToWrite []*ast.FuncDecl
BlockInfos []BlockInfo
}
func (p *processor) nodeProcessor(c *astutil.Cursor) bool {
n := c.Node()
if n == nil {
return false
}
// If not a function or it's empty skip and continue reading the file AST
funcDecl, ok := n.(*ast.FuncDecl)
if !ok || funcDecl.Body == nil || len(funcDecl.Body.List) == 0 {
return true
}
// Check if function has the required params
coroutineParamName := getCoroutineParamNameFromFuncDecl(funcDecl)
if coroutineParamName == "" {
return false
}
if !blockUsesCogo(funcDecl.Body, coroutineParamName, true) {
return false
}
// Generate code for function
p.processBlock(nil, funcDecl.Body, -1, coroutineParamName)
p.funcDeclsToWrite = append(p.funcDeclsToWrite, funcDecl)
return false
}
func (p *processor) processBlock(parentBlock, blockStmt *ast.BlockStmt, indexInParent int, coroutineParamName string) {
if !blockUsesCogo(blockStmt, coroutineParamName, true) {
return
}
p.pushNewBlock(parentBlock, blockStmt, indexInParent, &ast.SwitchStmt{
Tag: ast.NewIdent(coroutineParamName + ".State"),
Body: &ast.BlockStmt{
List: []ast.Stmt{},
},
})
defer p.popSwitch()
for i := 0; i < len(blockStmt.List); i++ {
stmt := blockStmt.List[i]
if ifStmt, ifStmtOk := stmt.(*ast.IfStmt); ifStmtOk {
p.processBlock(blockStmt, ifStmt.Body, i, coroutineParamName)
} else if forStmt, forStmtOk := stmt.(*ast.ForStmt); forStmtOk {
// @TODO: For loops need unique handling to convert to an if statement (if they use cogo)
p.processBlock(blockStmt, forStmt.Body, i, coroutineParamName)
} else if bStmt, blockStmtOk := stmt.(*ast.BlockStmt); blockStmtOk {
p.processBlock(blockStmt, bStmt, i, coroutineParamName)
}
selExpr, selExprArgs := tryGetSelExprFromStmt(stmt, coroutineParamName, "Yield")
if selExpr == nil {
continue
}
p.addYield(blockStmt, &i, selExprArgs, coroutineParamName)
}
}
func (p *processor) pushNewBlock(parentBlock, blockStmt *ast.BlockStmt, indexInParent int, switchStmt *ast.SwitchStmt) {
beforeBlockLblName := getLblNameFromSubStateNum(int32(len(p.BlockInfos)+1), 0)
if parentBlock != nil {
parentBlock.List = insertIntoArr[ast.Stmt](parentBlock.List, indexInParent, &ast.LabeledStmt{
Label: ast.NewIdent(beforeBlockLblName),
Stmt: &ast.EmptyStmt{},
})
}
// Add switch to start of block
blockStmt.List = insertIntoArr[ast.Stmt](blockStmt.List, 0, switchStmt)
p.BlockInfos = append(p.BlockInfos, BlockInfo{
Switch: switchStmt,
BeforeBlockLblName: beforeBlockLblName,
})
}
func (p *processor) currSwitch() *ast.SwitchStmt {
return p.BlockInfos[len(p.BlockInfos)-1].Switch
}
func (p *processor) currSwitchInfo() *BlockInfo {
return &p.BlockInfos[len(p.BlockInfos)-1]
}
func (p *processor) popSwitch() *ast.SwitchStmt {
s := p.BlockInfos[len(p.BlockInfos)-1]
p.BlockInfos = p.BlockInfos[:len(p.BlockInfos)-1]
return s.Switch
}
func toStr[T any](x T) string {
return fmt.Sprintf("%+v", x)
}
func (p *processor) addYield(block *ast.BlockStmt, listIndex *int, yieldArgs []ast.Expr, coroutineParamName string) {
// Update switch statements and find the new state value
newCaseCondition := p.currSwitchInfo().CaseCount() + 1
newLblName := getLblNameFromSubStateNum(int32(len(p.BlockInfos)), int32(newCaseCondition))
// Go over all switch cases from current inner block to outer most
// and add a new case in each to make us reach this new yield
p.currSwitchInfo().addCase([]ast.Stmt{
&ast.BranchStmt{
Tok: token.GOTO,
Label: ast.NewIdent(newLblName),
},
})
stateValNeededForNexSwitch := newCaseCondition
prevBlockBeforeLblName := p.currSwitchInfo().BeforeBlockLblName
for i := len(p.BlockInfos) - 2; i >= 0; i-- {
info := &p.BlockInfos[i]
info.addCase([]ast.Stmt{
&ast.AssignStmt{
Lhs: []ast.Expr{ast.NewIdent(coroutineParamName + ".State")},
Tok: token.ASSIGN,
Rhs: []ast.Expr{ast.NewIdent(toStr(stateValNeededForNexSwitch))},
},
&ast.BranchStmt{
Tok: token.GOTO,
Label: ast.NewIdent(prevBlockBeforeLblName),
},
})
stateValNeededForNexSwitch = info.CaseCount()
prevBlockBeforeLblName = info.BeforeBlockLblName
}
// Create and add yield block
newBlock := &ast.BlockStmt{
List: []ast.Stmt{
&ast.AssignStmt{
Lhs: []ast.Expr{ast.NewIdent(coroutineParamName + ".State")},
Tok: token.ASSIGN,
Rhs: []ast.Expr{ast.NewIdent(toStr(stateValNeededForNexSwitch))},
},
&ast.AssignStmt{
Lhs: []ast.Expr{ast.NewIdent(coroutineParamName + ".Out")},
Tok: token.ASSIGN,
Rhs: yieldArgs,
},
&ast.ReturnStmt{},
},
}
block.List[*listIndex] = newBlock
block.List = insertIntoArr[ast.Stmt](block.List, *listIndex+1, &ast.LabeledStmt{
Label: ast.NewIdent(newLblName),
Stmt: &ast.EmptyStmt{},
})
*listIndex++
} }
func (p *processor) genCogoFuncsNodeProcessor(c *astutil.Cursor) bool { func (p *processor) genCogoFuncsNodeProcessor(c *astutil.Cursor) bool {
@ -140,8 +335,8 @@ func (p *processor) genCogoFuncsNodeProcessor(c *astutil.Cursor) bool {
return false return false
} }
hasBegin := funcHasSelInBody(funcDecl, coroutineParamName, "Begin") hasBegin := blockHasOneOrMoreSels(funcDecl.Body, []SelExprInfo{{coroutineParamName, "Begin"}}, false)
hasYield := funcHasSelInBody(funcDecl, coroutineParamName, "Yield") hasYield := blockHasOneOrMoreSels(funcDecl.Body, []SelExprInfo{{coroutineParamName, "Yield"}}, false)
if !hasBegin && !hasYield { if !hasBegin && !hasYield {
return false return false
} }
@ -152,7 +347,7 @@ func (p *processor) genCogoFuncsNodeProcessor(c *astutil.Cursor) bool {
beginBodyListIndex := -1 beginBodyListIndex := -1
lastCaseEndBodyListIndex := -1 lastCaseEndBodyListIndex := -1
switchStmt := &ast.SwitchStmt{ mainSwitchStmt := &ast.SwitchStmt{
Tag: ast.NewIdent(coroutineParamName + ".State"), Tag: ast.NewIdent(coroutineParamName + ".State"),
Body: &ast.BlockStmt{ Body: &ast.BlockStmt{
List: []ast.Stmt{}, List: []ast.Stmt{},
@ -172,8 +367,8 @@ func (p *processor) genCogoFuncsNodeProcessor(c *astutil.Cursor) bool {
stmt := funcDecl.Body.List[i] stmt := funcDecl.Body.List[i]
var cogoFuncSelExpr *ast.SelectorExpr
var blockStmt *ast.BlockStmt var blockStmt *ast.BlockStmt
var selExpr *ast.SelectorExpr
if ifStmt, ifStmtOk := stmt.(*ast.IfStmt); ifStmtOk { if ifStmt, ifStmtOk := stmt.(*ast.IfStmt); ifStmtOk {
@ -186,7 +381,7 @@ func (p *processor) genCogoFuncsNodeProcessor(c *astutil.Cursor) bool {
} else if forStmt, forStmtOk := stmt.(*ast.ForStmt); forStmtOk { } else if forStmt, forStmtOk := stmt.(*ast.ForStmt); forStmtOk {
outInitBlock, postInitStmt, outIfStmt, subStateNums := p.genCogoForStmt(forStmt, coroutineParamName, len(switchStmt.Body.List)) outInitBlock, postInitStmt, outIfStmt, subStateNums := p.genCogoForStmt(forStmt, coroutineParamName, len(mainSwitchStmt.Body.List))
if len(subStateNums) > 1 { if len(subStateNums) > 1 {
panic("For loops currently don't support more than one yield") panic("For loops currently don't support more than one yield")
@ -202,7 +397,7 @@ func (p *processor) genCogoFuncsNodeProcessor(c *astutil.Cursor) bool {
postInitStmt, postInitStmt,
&ast.BranchStmt{ &ast.BranchStmt{
Tok: token.GOTO, Tok: token.GOTO,
Label: ast.NewIdent(getLblNameFromSubStateNum(subStateNums[0])), Label: ast.NewIdent(getLblNameFromSubStateNum(1, subStateNums[0])),
}, },
}, },
), ),
@ -213,7 +408,7 @@ func (p *processor) genCogoFuncsNodeProcessor(c *astutil.Cursor) bool {
// // Insert lable after initial condition // // Insert lable after initial condition
var stmtInterface ast.Stmt = &ast.LabeledStmt{ var stmtInterface ast.Stmt = &ast.LabeledStmt{
Label: ast.NewIdent(getLblNameFromSubStateNum(subStateNums[0])), Label: ast.NewIdent(getLblNameFromSubStateNum(1, subStateNums[0])),
Stmt: &ast.EmptyStmt{}, Stmt: &ast.EmptyStmt{},
} }
funcDecl.Body.List = insertIntoArr(funcDecl.Body.List, i+1, stmtInterface) funcDecl.Body.List = insertIntoArr(funcDecl.Body.List, i+1, stmtInterface)
@ -231,7 +426,7 @@ func (p *processor) genCogoFuncsNodeProcessor(c *astutil.Cursor) bool {
if blockStmt != nil { if blockStmt != nil {
subStateNums := p.genCogoBlockStmt(blockStmt, coroutineParamName, len(switchStmt.Body.List)) subStateNums := p.genCogoBlockStmt(blockStmt, coroutineParamName, len(mainSwitchStmt.Body.List))
for _, subStateNum := range subStateNums { for _, subStateNum := range subStateNums {
subSwitchStmt.Body.List = append(subSwitchStmt.Body.List, subSwitchStmt.Body.List = append(subSwitchStmt.Body.List,
@ -240,14 +435,14 @@ func (p *processor) genCogoFuncsNodeProcessor(c *astutil.Cursor) bool {
[]ast.Stmt{ []ast.Stmt{
&ast.BranchStmt{ &ast.BranchStmt{
Tok: token.GOTO, Tok: token.GOTO,
Label: ast.NewIdent(getLblNameFromSubStateNum(subStateNum)), Label: ast.NewIdent(getLblNameFromSubStateNum(1, subStateNum)),
}, },
}, },
), ),
) )
var stmtToInsert ast.Stmt = &ast.LabeledStmt{ var stmtToInsert ast.Stmt = &ast.LabeledStmt{
Label: ast.NewIdent(getLblNameFromSubStateNum(subStateNum)), Label: ast.NewIdent(getLblNameFromSubStateNum(1, subStateNum)),
Stmt: &ast.EmptyStmt{}, Stmt: &ast.EmptyStmt{},
} }
funcDecl.Body.List = insertIntoArr(funcDecl.Body.List, i+1, stmtToInsert) funcDecl.Body.List = insertIntoArr(funcDecl.Body.List, i+1, stmtToInsert)
@ -257,33 +452,33 @@ func (p *processor) genCogoFuncsNodeProcessor(c *astutil.Cursor) bool {
continue continue
} }
// Find functions calls in the style of 'xyz.ABC123()' // Find cogo function call in the style of 'cogo.Xyz()'
exprStmt, exprStmtOk := stmt.(*ast.ExprStmt) exprStmt, ok := stmt.(*ast.ExprStmt)
if !exprStmtOk { if !ok {
continue continue
} }
callExpr, exprStmtOk := exprStmt.X.(*ast.CallExpr) callExpr, ok := exprStmt.X.(*ast.CallExpr)
if !exprStmtOk { if !ok {
continue continue
} }
cogoFuncSelExpr, exprStmtOk = callExpr.Fun.(*ast.SelectorExpr) selExpr, ok = callExpr.Fun.(*ast.SelectorExpr)
if !exprStmtOk { if !ok {
continue continue
} }
if !selExprHasLhsName(cogoFuncSelExpr, coroutineParamName) { if !selExprHasLhsName(selExpr, coroutineParamName) {
continue continue
} }
// Now that we found a call to cogo decide what to do // Now that we found a call to cogo decide what to do
if cogoFuncSelExpr.Sel.Name == "Begin" { if selExpr.Sel.Name == "Begin" {
beginBodyListIndex = i beginBodyListIndex = i
lastCaseEndBodyListIndex = i lastCaseEndBodyListIndex = i
continue continue
} else if cogoFuncSelExpr.Sel.Name == "Yield" { } else if selExpr.Sel.Name == "Yield" {
// Add everything from the last begin/yield until this yield into a case // Add everything from the last begin/yield until this yield into a case
stmtsSinceLastCogo := funcDecl.Body.List[lastCaseEndBodyListIndex+1 : i] stmtsSinceLastCogo := funcDecl.Body.List[lastCaseEndBodyListIndex+1 : i]
@ -319,16 +514,16 @@ func (p *processor) genCogoFuncsNodeProcessor(c *astutil.Cursor) bool {
&ast.ReturnStmt{}, &ast.ReturnStmt{},
) )
switchStmt.Body.List = append(switchStmt.Body.List, mainSwitchStmt.Body.List = append(mainSwitchStmt.Body.List,
getCaseWithStmts( getCaseWithStmts(
[]ast.Expr{ast.NewIdent(fmt.Sprint(len(switchStmt.Body.List)))}, []ast.Expr{ast.NewIdent(fmt.Sprint(len(mainSwitchStmt.Body.List)))},
caseStmts, caseStmts,
), ),
) )
lastCaseEndBodyListIndex = i lastCaseEndBodyListIndex = i
} else if cogoFuncSelExpr.Sel.Name == "YieldTo" { } else if selExpr.Sel.Name == "YieldTo" {
// Add everything from the last begin/yield until this yield into a case // Add everything from the last begin/yield until this yield into a case
stmtsSinceLastCogo := funcDecl.Body.List[lastCaseEndBodyListIndex+1 : i] stmtsSinceLastCogo := funcDecl.Body.List[lastCaseEndBodyListIndex+1 : i]
@ -364,15 +559,15 @@ func (p *processor) genCogoFuncsNodeProcessor(c *astutil.Cursor) bool {
&ast.ReturnStmt{}, &ast.ReturnStmt{},
) )
switchStmt.Body.List = append(switchStmt.Body.List, mainSwitchStmt.Body.List = append(mainSwitchStmt.Body.List,
getCaseWithStmts( getCaseWithStmts(
[]ast.Expr{ast.NewIdent(fmt.Sprint(len(switchStmt.Body.List)))}, []ast.Expr{ast.NewIdent(fmt.Sprint(len(mainSwitchStmt.Body.List)))},
caseStmts, caseStmts,
), ),
) )
lastCaseEndBodyListIndex = i lastCaseEndBodyListIndex = i
} else if cogoFuncSelExpr.Sel.Name == "YieldNone" { } else if selExpr.Sel.Name == "YieldNone" {
// Add everything from the last begin/yield until this yield into a case // Add everything from the last begin/yield until this yield into a case
stmtsSinceLastCogo := funcDecl.Body.List[lastCaseEndBodyListIndex+1 : i] stmtsSinceLastCogo := funcDecl.Body.List[lastCaseEndBodyListIndex+1 : i]
@ -403,9 +598,9 @@ func (p *processor) genCogoFuncsNodeProcessor(c *astutil.Cursor) bool {
&ast.ReturnStmt{}, &ast.ReturnStmt{},
) )
switchStmt.Body.List = append(switchStmt.Body.List, mainSwitchStmt.Body.List = append(mainSwitchStmt.Body.List,
getCaseWithStmts( getCaseWithStmts(
[]ast.Expr{ast.NewIdent(fmt.Sprint(len(switchStmt.Body.List)))}, []ast.Expr{ast.NewIdent(fmt.Sprint(len(mainSwitchStmt.Body.List)))},
caseStmts, caseStmts,
), ),
) )
@ -443,9 +638,9 @@ func (p *processor) genCogoFuncsNodeProcessor(c *astutil.Cursor) bool {
) )
caseStmts = append(caseStmts, stmtsToEndOfFunc...) caseStmts = append(caseStmts, stmtsToEndOfFunc...)
switchStmt.Body.List = append(switchStmt.Body.List, mainSwitchStmt.Body.List = append(mainSwitchStmt.Body.List,
getCaseWithStmts( getCaseWithStmts(
[]ast.Expr{ast.NewIdent(fmt.Sprint(len(switchStmt.Body.List)))}, []ast.Expr{ast.NewIdent(fmt.Sprint(len(mainSwitchStmt.Body.List)))},
caseStmts, caseStmts,
), ),
@ -471,7 +666,7 @@ func (p *processor) genCogoFuncsNodeProcessor(c *astutil.Cursor) bool {
// Apply changes // Apply changes
funcDecl.Body.List = funcDecl.Body.List[:beginBodyListIndex] funcDecl.Body.List = funcDecl.Body.List[:beginBodyListIndex]
funcDecl.Body.List = append(funcDecl.Body.List, funcDecl.Body.List = append(funcDecl.Body.List,
switchStmt, mainSwitchStmt,
) )
originalList := funcDecl.Body.List originalList := funcDecl.Body.List
@ -586,8 +781,8 @@ func (p *processor) genCogoBlockStmt(blockStmt *ast.BlockStmt, coroutineParamNam
return subStateNums return subStateNums
} }
func getLblNameFromSubStateNum(subStateNum int32) string { func getLblNameFromSubStateNum(switchNum, caseNum int32) string {
return fmt.Sprint("cogo_", subStateNum) return fmt.Sprintf("cogo_%d_%d", switchNum, caseNum)
} }
func insertIntoArr[T any](a []T, index int, value T) []T { func insertIntoArr[T any](a []T, index int, value T) []T {
@ -639,7 +834,7 @@ func (p *processor) genHasGenChecksOnOriginalFuncsNodeProcessor(c *astutil.Curso
} }
coroutineParamName := getCoroutineParamNameFromFuncDecl(funcDecl) coroutineParamName := getCoroutineParamNameFromFuncDecl(funcDecl)
if coroutineParamName == "" || !funcHasSelInBody(funcDecl, coroutineParamName, "Begin") { if coroutineParamName == "" || !blockHasOneOrMoreSels(funcDecl.Body, []SelExprInfo{{coroutineParamName, "Begin"}}, false) {
return false return false
} }
@ -707,13 +902,33 @@ func createStmtFromSelFuncCall(lhs, rhs string) ast.Stmt {
} }
} }
func funcHasSelInBody(fd *ast.FuncDecl, selLhs, selRhs string) bool { type SelExprInfo struct {
// Give `cogo.Yield()`, Lhs would be 'cogo'
Lhs string
// Give `cogo.Yield()`, Rhs would be 'Yield'
Rhs string
}
if fd.Body == nil || len(fd.Body.List) == 0 { func blockUsesCogo(block *ast.BlockStmt, coroutineParamName string, checkChildBlocks bool) bool {
return blockHasOneOrMoreSels(block, []SelExprInfo{{coroutineParamName, "Yield"}, {coroutineParamName, "YieldTo"}, {coroutineParamName, "YieldNone"}}, checkChildBlocks)
}
func blockHasOneOrMoreSels(block *ast.BlockStmt, sels []SelExprInfo, checkChildBlocks bool) bool {
if block == nil || len(block.List) == 0 {
return false return false
} }
for _, stmt := range fd.Body.List { for _, stmt := range block.List {
// Check recursively if requested
if checkChildBlocks {
if blockStmt, ok := stmt.(*ast.BlockStmt); ok {
if blockHasOneOrMoreSels(blockStmt, sels, true) {
return true
}
}
}
// Find functions calls in the style of 'cogo.ABC123()' // Find functions calls in the style of 'cogo.ABC123()'
exprStmt, ok := stmt.(*ast.ExprStmt) exprStmt, ok := stmt.(*ast.ExprStmt)
@ -731,8 +946,10 @@ func funcHasSelInBody(fd *ast.FuncDecl, selLhs, selRhs string) bool {
continue continue
} }
if selExprIs(selExpr, selLhs, selRhs) { for _, v := range sels {
return true if selExprIs(selExpr, v.Lhs, v.Rhs) {
return true
}
} }
} }
@ -789,10 +1006,10 @@ func selExprIs(selExpr *ast.SelectorExpr, lhs, rhs string) bool {
return pkgIdentExpr.Name == lhs && selExpr.Sel.Name == rhs return pkgIdentExpr.Name == lhs && selExpr.Sel.Name == rhs
} }
func getCaseWithStmts(caseConditions []ast.Expr, stmts []ast.Stmt) *ast.CaseClause { func getCaseWithStmts(caseConditions []ast.Expr, bodyStmts []ast.Stmt) *ast.CaseClause {
return &ast.CaseClause{ return &ast.CaseClause{
List: caseConditions, List: caseConditions,
Body: stmts, Body: bodyStmts,
} }
} }