mirror of
https://github.com/bloeys/cogo.git
synced 2025-12-29 08:58:19 +00:00
New much better algorithm
This commit is contained in:
105
demo.cogo.go
105
demo.cogo.go
@ -9,97 +9,46 @@ import (
|
|||||||
|
|
||||||
func test_cogo(c *cogo.Coroutine[int, int]) {
|
func test_cogo(c *cogo.Coroutine[int, int]) {
|
||||||
switch c.State {
|
switch c.State {
|
||||||
case 0:
|
case 1:
|
||||||
switch c.SubState {
|
goto cogo_1_1
|
||||||
default:
|
case 2:
|
||||||
|
c.State = 1
|
||||||
|
goto cogo_2_0
|
||||||
|
case 3:
|
||||||
|
goto cogo_1_3
|
||||||
}
|
}
|
||||||
|
|
||||||
println("test yield:", 1)
|
println("test yield:", 1)
|
||||||
c.State++
|
{
|
||||||
c.SubState = -1
|
c.State = 1
|
||||||
c.Out = 1
|
c.Out = 1
|
||||||
return
|
return
|
||||||
|
}
|
||||||
|
cogo_1_1:
|
||||||
|
;
|
||||||
|
cogo_2_0:
|
||||||
|
;
|
||||||
|
|
||||||
|
if c.Out > 2 {
|
||||||
|
switch c.State {
|
||||||
case 1:
|
case 1:
|
||||||
switch c.SubState {
|
goto cogo_2_1
|
||||||
default:
|
|
||||||
}
|
}
|
||||||
c.State++
|
{
|
||||||
c.SubState = -1
|
c.State = 2
|
||||||
c.Yielder = cogo.NewSleeper(100 * time.Millisecond)
|
c.Out = 1
|
||||||
return
|
return
|
||||||
case 2:
|
|
||||||
switch c.SubState {
|
|
||||||
default:
|
|
||||||
}
|
}
|
||||||
c.State++
|
cogo_2_1:
|
||||||
c.SubState = -1
|
|
||||||
c.Yielder = cogo.New(test2, 0)
|
|
||||||
return
|
|
||||||
case 3:
|
|
||||||
switch c.SubState {
|
|
||||||
default:
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
c.YieldTo(cogo.NewSleeper(100 * time.Millisecond))
|
||||||
|
|
||||||
println("test yield:", 2)
|
println("test yield:", 2)
|
||||||
c.State++
|
{
|
||||||
c.SubState = -1
|
c.State = 3
|
||||||
c.Out = 2
|
c.Out = 2
|
||||||
return
|
return
|
||||||
case 4:
|
|
||||||
switch c.SubState {
|
|
||||||
default:
|
|
||||||
}
|
|
||||||
c.State = -1
|
|
||||||
c.SubState = -1
|
|
||||||
default:
|
|
||||||
c.State = -1
|
|
||||||
c.SubState = -1
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func test2_cogo(c *cogo.Coroutine[int, int]) {
|
|
||||||
switch c.State {
|
|
||||||
case 0:
|
|
||||||
switch c.SubState {
|
|
||||||
default:
|
|
||||||
}
|
|
||||||
|
|
||||||
println("test2222 yield:", 1)
|
|
||||||
c.State++
|
|
||||||
c.SubState = -1
|
|
||||||
c.Out = 1
|
|
||||||
return
|
|
||||||
case 1:
|
|
||||||
switch c.SubState {
|
|
||||||
default:
|
|
||||||
}
|
|
||||||
|
|
||||||
println("test2222 yield:", 2)
|
|
||||||
c.State++
|
|
||||||
c.SubState = -1
|
|
||||||
c.Out = 2
|
|
||||||
return
|
|
||||||
case 2:
|
|
||||||
switch c.SubState {
|
|
||||||
default:
|
|
||||||
}
|
|
||||||
|
|
||||||
println("test2222 before yield none")
|
|
||||||
c.State++
|
|
||||||
c.SubState = -1
|
|
||||||
return
|
|
||||||
case 3:
|
|
||||||
switch c.SubState {
|
|
||||||
default:
|
|
||||||
}
|
|
||||||
c.State = -1
|
|
||||||
c.SubState = -1
|
|
||||||
|
|
||||||
println("test2222 after yield none")
|
|
||||||
default:
|
|
||||||
c.State = -1
|
|
||||||
c.SubState = -1
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
cogo_1_3:
|
||||||
}
|
}
|
||||||
|
|||||||
70
demo.go
70
demo.go
@ -23,41 +23,73 @@ func runDemo() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func test(c *cogo.Coroutine[int, int]) {
|
func test(c *cogo.Coroutine[int, int]) {
|
||||||
if cogo.HasGen() {
|
|
||||||
test_cogo(c)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
c.Begin()
|
|
||||||
|
|
||||||
println("test yield:", 1)
|
println("test yield:", 1)
|
||||||
c.Yield(1)
|
c.Yield(1)
|
||||||
|
|
||||||
|
if c.Out > 2 {
|
||||||
|
c.Yield(1)
|
||||||
|
}
|
||||||
|
|
||||||
// Yield here until at least 100ms passed
|
// Yield here until at least 100ms passed
|
||||||
c.YieldTo(cogo.NewSleeper(100 * time.Millisecond))
|
c.YieldTo(cogo.NewSleeper(100 * time.Millisecond))
|
||||||
|
|
||||||
// Yield here until the coroutine 'test2' has finished
|
// Yield here until the coroutine 'test2' has finished
|
||||||
c.YieldTo(cogo.New(test2, 0))
|
// c.YieldTo(cogo.New(test2, 0))
|
||||||
|
|
||||||
println("test yield:", 2)
|
println("test yield:", 2)
|
||||||
c.Yield(2)
|
c.Yield(2)
|
||||||
}
|
}
|
||||||
|
|
||||||
func test2(c *cogo.Coroutine[int, int]) {
|
// func test2(c *cogo.Coroutine[int, int]) {
|
||||||
if cogo.HasGen() {
|
|
||||||
test2_cogo(c)
|
// println("test2222 yield:", 1)
|
||||||
return
|
// c.Yield(1)
|
||||||
|
|
||||||
|
// println("test2222 yield:", 2)
|
||||||
|
// c.Yield(2)
|
||||||
|
|
||||||
|
// println("test2222 before yield none")
|
||||||
|
// c.YieldNone()
|
||||||
|
// println("test2222 after yield none")
|
||||||
|
// }
|
||||||
|
|
||||||
|
func NewApproach(state int) {
|
||||||
|
|
||||||
|
switch state {
|
||||||
|
case 1:
|
||||||
|
goto lbl_1
|
||||||
|
case 2:
|
||||||
|
goto lbl_2
|
||||||
|
case 3:
|
||||||
|
state = 1
|
||||||
|
goto lbl_2
|
||||||
|
default:
|
||||||
}
|
}
|
||||||
|
|
||||||
c.Begin()
|
println("1")
|
||||||
|
println("2")
|
||||||
|
state = 1
|
||||||
|
// return
|
||||||
|
|
||||||
println("test2222 yield:", 1)
|
lbl_1:
|
||||||
c.Yield(1)
|
println("3")
|
||||||
|
state = 2
|
||||||
|
// return
|
||||||
|
|
||||||
println("test2222 yield:", 2)
|
lbl_2:
|
||||||
c.Yield(2)
|
{
|
||||||
|
switch state {
|
||||||
|
case 1:
|
||||||
|
goto lbl_3
|
||||||
|
|
||||||
println("test2222 before yield none")
|
default:
|
||||||
c.YieldNone()
|
}
|
||||||
println("test2222 after yield none")
|
|
||||||
|
println("4")
|
||||||
|
state = 3
|
||||||
|
// return
|
||||||
|
|
||||||
|
lbl_3:
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
301
main.go
301
main.go
@ -35,7 +35,7 @@ func main() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
genCogoFuncs(cwd)
|
genCogoFuncs(cwd)
|
||||||
genHasGenChecksOnOriginalFuncs(cwd)
|
// genHasGenChecksOnOriginalFuncs(cwd)
|
||||||
}
|
}
|
||||||
|
|
||||||
func genCogoFuncs(cwd string) {
|
func genCogoFuncs(cwd string) {
|
||||||
@ -54,10 +54,12 @@ func genCogoFuncs(cwd string) {
|
|||||||
p := &processor{
|
p := &processor{
|
||||||
fset: pkg.Fset,
|
fset: pkg.Fset,
|
||||||
funcDeclsToWrite: []*ast.FuncDecl{},
|
funcDeclsToWrite: []*ast.FuncDecl{},
|
||||||
|
BlockInfos: make([]BlockInfo, 0, 10),
|
||||||
}
|
}
|
||||||
|
|
||||||
for i, synFile := range pkg.Syntax {
|
for i, synFile := range pkg.Syntax {
|
||||||
pkg.Syntax[i] = astutil.Apply(synFile, p.genCogoFuncsNodeProcessor, nil).(*ast.File)
|
pkg.Syntax[i] = astutil.Apply(synFile, p.nodeProcessor, nil).(*ast.File)
|
||||||
|
// pkg.Syntax[i] = astutil.Apply(synFile, p.genCogoFuncsNodeProcessor, nil).(*ast.File)
|
||||||
|
|
||||||
if len(p.funcDeclsToWrite) > 0 {
|
if len(p.funcDeclsToWrite) > 0 {
|
||||||
|
|
||||||
@ -101,6 +103,7 @@ func genHasGenChecksOnOriginalFuncs(cwd string) {
|
|||||||
p := &processor{
|
p := &processor{
|
||||||
fset: pkg.Fset,
|
fset: pkg.Fset,
|
||||||
funcDeclsToWrite: []*ast.FuncDecl{},
|
funcDeclsToWrite: []*ast.FuncDecl{},
|
||||||
|
BlockInfos: make([]BlockInfo, 0, 10),
|
||||||
}
|
}
|
||||||
|
|
||||||
for i, synFile := range pkg.Syntax {
|
for i, synFile := range pkg.Syntax {
|
||||||
@ -118,9 +121,201 @@ func genHasGenChecksOnOriginalFuncs(cwd string) {
|
|||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type BlockInfo struct {
|
||||||
|
Switch *ast.SwitchStmt
|
||||||
|
BeforeBlockLblName string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *BlockInfo) addCase(caseStmts []ast.Stmt) int {
|
||||||
|
|
||||||
|
oldCaseCount := s.CaseCount()
|
||||||
|
newCase := getCaseWithStmts([]ast.Expr{
|
||||||
|
ast.NewIdent(fmt.Sprint(oldCaseCount + 1))},
|
||||||
|
caseStmts,
|
||||||
|
)
|
||||||
|
|
||||||
|
s.Switch.Body.List = append(s.Switch.Body.List, newCase)
|
||||||
|
return len(s.Switch.Body.List)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *BlockInfo) CaseCount() int {
|
||||||
|
return len(s.Switch.Body.List)
|
||||||
|
}
|
||||||
|
|
||||||
type processor struct {
|
type processor struct {
|
||||||
fset *token.FileSet
|
fset *token.FileSet
|
||||||
funcDeclsToWrite []*ast.FuncDecl
|
funcDeclsToWrite []*ast.FuncDecl
|
||||||
|
BlockInfos []BlockInfo
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *processor) nodeProcessor(c *astutil.Cursor) bool {
|
||||||
|
|
||||||
|
n := c.Node()
|
||||||
|
if n == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// If not a function or it's empty skip and continue reading the file AST
|
||||||
|
funcDecl, ok := n.(*ast.FuncDecl)
|
||||||
|
if !ok || funcDecl.Body == nil || len(funcDecl.Body.List) == 0 {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if function has the required params
|
||||||
|
coroutineParamName := getCoroutineParamNameFromFuncDecl(funcDecl)
|
||||||
|
if coroutineParamName == "" {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
if !blockUsesCogo(funcDecl.Body, coroutineParamName, true) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// Generate code for function
|
||||||
|
p.processBlock(nil, funcDecl.Body, -1, coroutineParamName)
|
||||||
|
p.funcDeclsToWrite = append(p.funcDeclsToWrite, funcDecl)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *processor) processBlock(parentBlock, blockStmt *ast.BlockStmt, indexInParent int, coroutineParamName string) {
|
||||||
|
|
||||||
|
if !blockUsesCogo(blockStmt, coroutineParamName, true) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
p.pushNewBlock(parentBlock, blockStmt, indexInParent, &ast.SwitchStmt{
|
||||||
|
Tag: ast.NewIdent(coroutineParamName + ".State"),
|
||||||
|
Body: &ast.BlockStmt{
|
||||||
|
List: []ast.Stmt{},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
defer p.popSwitch()
|
||||||
|
|
||||||
|
for i := 0; i < len(blockStmt.List); i++ {
|
||||||
|
|
||||||
|
stmt := blockStmt.List[i]
|
||||||
|
|
||||||
|
if ifStmt, ifStmtOk := stmt.(*ast.IfStmt); ifStmtOk {
|
||||||
|
|
||||||
|
p.processBlock(blockStmt, ifStmt.Body, i, coroutineParamName)
|
||||||
|
|
||||||
|
} else if forStmt, forStmtOk := stmt.(*ast.ForStmt); forStmtOk {
|
||||||
|
|
||||||
|
// @TODO: For loops need unique handling to convert to an if statement (if they use cogo)
|
||||||
|
p.processBlock(blockStmt, forStmt.Body, i, coroutineParamName)
|
||||||
|
|
||||||
|
} else if bStmt, blockStmtOk := stmt.(*ast.BlockStmt); blockStmtOk {
|
||||||
|
|
||||||
|
p.processBlock(blockStmt, bStmt, i, coroutineParamName)
|
||||||
|
}
|
||||||
|
|
||||||
|
selExpr, selExprArgs := tryGetSelExprFromStmt(stmt, coroutineParamName, "Yield")
|
||||||
|
if selExpr == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
p.addYield(blockStmt, &i, selExprArgs, coroutineParamName)
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *processor) pushNewBlock(parentBlock, blockStmt *ast.BlockStmt, indexInParent int, switchStmt *ast.SwitchStmt) {
|
||||||
|
|
||||||
|
beforeBlockLblName := getLblNameFromSubStateNum(int32(len(p.BlockInfos)+1), 0)
|
||||||
|
if parentBlock != nil {
|
||||||
|
parentBlock.List = insertIntoArr[ast.Stmt](parentBlock.List, indexInParent, &ast.LabeledStmt{
|
||||||
|
Label: ast.NewIdent(beforeBlockLblName),
|
||||||
|
Stmt: &ast.EmptyStmt{},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add switch to start of block
|
||||||
|
blockStmt.List = insertIntoArr[ast.Stmt](blockStmt.List, 0, switchStmt)
|
||||||
|
|
||||||
|
p.BlockInfos = append(p.BlockInfos, BlockInfo{
|
||||||
|
Switch: switchStmt,
|
||||||
|
BeforeBlockLblName: beforeBlockLblName,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *processor) currSwitch() *ast.SwitchStmt {
|
||||||
|
return p.BlockInfos[len(p.BlockInfos)-1].Switch
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *processor) currSwitchInfo() *BlockInfo {
|
||||||
|
return &p.BlockInfos[len(p.BlockInfos)-1]
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *processor) popSwitch() *ast.SwitchStmt {
|
||||||
|
|
||||||
|
s := p.BlockInfos[len(p.BlockInfos)-1]
|
||||||
|
p.BlockInfos = p.BlockInfos[:len(p.BlockInfos)-1]
|
||||||
|
return s.Switch
|
||||||
|
}
|
||||||
|
|
||||||
|
func toStr[T any](x T) string {
|
||||||
|
return fmt.Sprintf("%+v", x)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *processor) addYield(block *ast.BlockStmt, listIndex *int, yieldArgs []ast.Expr, coroutineParamName string) {
|
||||||
|
|
||||||
|
// Update switch statements and find the new state value
|
||||||
|
newCaseCondition := p.currSwitchInfo().CaseCount() + 1
|
||||||
|
newLblName := getLblNameFromSubStateNum(int32(len(p.BlockInfos)), int32(newCaseCondition))
|
||||||
|
|
||||||
|
// Go over all switch cases from current inner block to outer most
|
||||||
|
// and add a new case in each to make us reach this new yield
|
||||||
|
p.currSwitchInfo().addCase([]ast.Stmt{
|
||||||
|
&ast.BranchStmt{
|
||||||
|
Tok: token.GOTO,
|
||||||
|
Label: ast.NewIdent(newLblName),
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
stateValNeededForNexSwitch := newCaseCondition
|
||||||
|
prevBlockBeforeLblName := p.currSwitchInfo().BeforeBlockLblName
|
||||||
|
for i := len(p.BlockInfos) - 2; i >= 0; i-- {
|
||||||
|
|
||||||
|
info := &p.BlockInfos[i]
|
||||||
|
info.addCase([]ast.Stmt{
|
||||||
|
&ast.AssignStmt{
|
||||||
|
Lhs: []ast.Expr{ast.NewIdent(coroutineParamName + ".State")},
|
||||||
|
Tok: token.ASSIGN,
|
||||||
|
Rhs: []ast.Expr{ast.NewIdent(toStr(stateValNeededForNexSwitch))},
|
||||||
|
},
|
||||||
|
&ast.BranchStmt{
|
||||||
|
Tok: token.GOTO,
|
||||||
|
Label: ast.NewIdent(prevBlockBeforeLblName),
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
stateValNeededForNexSwitch = info.CaseCount()
|
||||||
|
prevBlockBeforeLblName = info.BeforeBlockLblName
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create and add yield block
|
||||||
|
newBlock := &ast.BlockStmt{
|
||||||
|
List: []ast.Stmt{
|
||||||
|
&ast.AssignStmt{
|
||||||
|
Lhs: []ast.Expr{ast.NewIdent(coroutineParamName + ".State")},
|
||||||
|
Tok: token.ASSIGN,
|
||||||
|
Rhs: []ast.Expr{ast.NewIdent(toStr(stateValNeededForNexSwitch))},
|
||||||
|
},
|
||||||
|
&ast.AssignStmt{
|
||||||
|
Lhs: []ast.Expr{ast.NewIdent(coroutineParamName + ".Out")},
|
||||||
|
Tok: token.ASSIGN,
|
||||||
|
Rhs: yieldArgs,
|
||||||
|
},
|
||||||
|
&ast.ReturnStmt{},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
block.List[*listIndex] = newBlock
|
||||||
|
block.List = insertIntoArr[ast.Stmt](block.List, *listIndex+1, &ast.LabeledStmt{
|
||||||
|
Label: ast.NewIdent(newLblName),
|
||||||
|
Stmt: &ast.EmptyStmt{},
|
||||||
|
})
|
||||||
|
*listIndex++
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *processor) genCogoFuncsNodeProcessor(c *astutil.Cursor) bool {
|
func (p *processor) genCogoFuncsNodeProcessor(c *astutil.Cursor) bool {
|
||||||
@ -140,8 +335,8 @@ func (p *processor) genCogoFuncsNodeProcessor(c *astutil.Cursor) bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
hasBegin := funcHasSelInBody(funcDecl, coroutineParamName, "Begin")
|
hasBegin := blockHasOneOrMoreSels(funcDecl.Body, []SelExprInfo{{coroutineParamName, "Begin"}}, false)
|
||||||
hasYield := funcHasSelInBody(funcDecl, coroutineParamName, "Yield")
|
hasYield := blockHasOneOrMoreSels(funcDecl.Body, []SelExprInfo{{coroutineParamName, "Yield"}}, false)
|
||||||
if !hasBegin && !hasYield {
|
if !hasBegin && !hasYield {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
@ -152,7 +347,7 @@ func (p *processor) genCogoFuncsNodeProcessor(c *astutil.Cursor) bool {
|
|||||||
|
|
||||||
beginBodyListIndex := -1
|
beginBodyListIndex := -1
|
||||||
lastCaseEndBodyListIndex := -1
|
lastCaseEndBodyListIndex := -1
|
||||||
switchStmt := &ast.SwitchStmt{
|
mainSwitchStmt := &ast.SwitchStmt{
|
||||||
Tag: ast.NewIdent(coroutineParamName + ".State"),
|
Tag: ast.NewIdent(coroutineParamName + ".State"),
|
||||||
Body: &ast.BlockStmt{
|
Body: &ast.BlockStmt{
|
||||||
List: []ast.Stmt{},
|
List: []ast.Stmt{},
|
||||||
@ -172,8 +367,8 @@ func (p *processor) genCogoFuncsNodeProcessor(c *astutil.Cursor) bool {
|
|||||||
|
|
||||||
stmt := funcDecl.Body.List[i]
|
stmt := funcDecl.Body.List[i]
|
||||||
|
|
||||||
var cogoFuncSelExpr *ast.SelectorExpr
|
|
||||||
var blockStmt *ast.BlockStmt
|
var blockStmt *ast.BlockStmt
|
||||||
|
var selExpr *ast.SelectorExpr
|
||||||
|
|
||||||
if ifStmt, ifStmtOk := stmt.(*ast.IfStmt); ifStmtOk {
|
if ifStmt, ifStmtOk := stmt.(*ast.IfStmt); ifStmtOk {
|
||||||
|
|
||||||
@ -186,7 +381,7 @@ func (p *processor) genCogoFuncsNodeProcessor(c *astutil.Cursor) bool {
|
|||||||
|
|
||||||
} else if forStmt, forStmtOk := stmt.(*ast.ForStmt); forStmtOk {
|
} else if forStmt, forStmtOk := stmt.(*ast.ForStmt); forStmtOk {
|
||||||
|
|
||||||
outInitBlock, postInitStmt, outIfStmt, subStateNums := p.genCogoForStmt(forStmt, coroutineParamName, len(switchStmt.Body.List))
|
outInitBlock, postInitStmt, outIfStmt, subStateNums := p.genCogoForStmt(forStmt, coroutineParamName, len(mainSwitchStmt.Body.List))
|
||||||
|
|
||||||
if len(subStateNums) > 1 {
|
if len(subStateNums) > 1 {
|
||||||
panic("For loops currently don't support more than one yield")
|
panic("For loops currently don't support more than one yield")
|
||||||
@ -202,7 +397,7 @@ func (p *processor) genCogoFuncsNodeProcessor(c *astutil.Cursor) bool {
|
|||||||
postInitStmt,
|
postInitStmt,
|
||||||
&ast.BranchStmt{
|
&ast.BranchStmt{
|
||||||
Tok: token.GOTO,
|
Tok: token.GOTO,
|
||||||
Label: ast.NewIdent(getLblNameFromSubStateNum(subStateNums[0])),
|
Label: ast.NewIdent(getLblNameFromSubStateNum(1, subStateNums[0])),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
),
|
),
|
||||||
@ -213,7 +408,7 @@ func (p *processor) genCogoFuncsNodeProcessor(c *astutil.Cursor) bool {
|
|||||||
|
|
||||||
// // Insert lable after initial condition
|
// // Insert lable after initial condition
|
||||||
var stmtInterface ast.Stmt = &ast.LabeledStmt{
|
var stmtInterface ast.Stmt = &ast.LabeledStmt{
|
||||||
Label: ast.NewIdent(getLblNameFromSubStateNum(subStateNums[0])),
|
Label: ast.NewIdent(getLblNameFromSubStateNum(1, subStateNums[0])),
|
||||||
Stmt: &ast.EmptyStmt{},
|
Stmt: &ast.EmptyStmt{},
|
||||||
}
|
}
|
||||||
funcDecl.Body.List = insertIntoArr(funcDecl.Body.List, i+1, stmtInterface)
|
funcDecl.Body.List = insertIntoArr(funcDecl.Body.List, i+1, stmtInterface)
|
||||||
@ -231,7 +426,7 @@ func (p *processor) genCogoFuncsNodeProcessor(c *astutil.Cursor) bool {
|
|||||||
|
|
||||||
if blockStmt != nil {
|
if blockStmt != nil {
|
||||||
|
|
||||||
subStateNums := p.genCogoBlockStmt(blockStmt, coroutineParamName, len(switchStmt.Body.List))
|
subStateNums := p.genCogoBlockStmt(blockStmt, coroutineParamName, len(mainSwitchStmt.Body.List))
|
||||||
for _, subStateNum := range subStateNums {
|
for _, subStateNum := range subStateNums {
|
||||||
|
|
||||||
subSwitchStmt.Body.List = append(subSwitchStmt.Body.List,
|
subSwitchStmt.Body.List = append(subSwitchStmt.Body.List,
|
||||||
@ -240,14 +435,14 @@ func (p *processor) genCogoFuncsNodeProcessor(c *astutil.Cursor) bool {
|
|||||||
[]ast.Stmt{
|
[]ast.Stmt{
|
||||||
&ast.BranchStmt{
|
&ast.BranchStmt{
|
||||||
Tok: token.GOTO,
|
Tok: token.GOTO,
|
||||||
Label: ast.NewIdent(getLblNameFromSubStateNum(subStateNum)),
|
Label: ast.NewIdent(getLblNameFromSubStateNum(1, subStateNum)),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
var stmtToInsert ast.Stmt = &ast.LabeledStmt{
|
var stmtToInsert ast.Stmt = &ast.LabeledStmt{
|
||||||
Label: ast.NewIdent(getLblNameFromSubStateNum(subStateNum)),
|
Label: ast.NewIdent(getLblNameFromSubStateNum(1, subStateNum)),
|
||||||
Stmt: &ast.EmptyStmt{},
|
Stmt: &ast.EmptyStmt{},
|
||||||
}
|
}
|
||||||
funcDecl.Body.List = insertIntoArr(funcDecl.Body.List, i+1, stmtToInsert)
|
funcDecl.Body.List = insertIntoArr(funcDecl.Body.List, i+1, stmtToInsert)
|
||||||
@ -257,33 +452,33 @@ func (p *processor) genCogoFuncsNodeProcessor(c *astutil.Cursor) bool {
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
// Find functions calls in the style of 'xyz.ABC123()'
|
// Find cogo function call in the style of 'cogo.Xyz()'
|
||||||
exprStmt, exprStmtOk := stmt.(*ast.ExprStmt)
|
exprStmt, ok := stmt.(*ast.ExprStmt)
|
||||||
if !exprStmtOk {
|
if !ok {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
callExpr, exprStmtOk := exprStmt.X.(*ast.CallExpr)
|
callExpr, ok := exprStmt.X.(*ast.CallExpr)
|
||||||
if !exprStmtOk {
|
if !ok {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
cogoFuncSelExpr, exprStmtOk = callExpr.Fun.(*ast.SelectorExpr)
|
selExpr, ok = callExpr.Fun.(*ast.SelectorExpr)
|
||||||
if !exprStmtOk {
|
if !ok {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
if !selExprHasLhsName(cogoFuncSelExpr, coroutineParamName) {
|
if !selExprHasLhsName(selExpr, coroutineParamName) {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
// Now that we found a call to cogo decide what to do
|
// Now that we found a call to cogo decide what to do
|
||||||
if cogoFuncSelExpr.Sel.Name == "Begin" {
|
if selExpr.Sel.Name == "Begin" {
|
||||||
|
|
||||||
beginBodyListIndex = i
|
beginBodyListIndex = i
|
||||||
lastCaseEndBodyListIndex = i
|
lastCaseEndBodyListIndex = i
|
||||||
continue
|
continue
|
||||||
} else if cogoFuncSelExpr.Sel.Name == "Yield" {
|
} else if selExpr.Sel.Name == "Yield" {
|
||||||
|
|
||||||
// Add everything from the last begin/yield until this yield into a case
|
// Add everything from the last begin/yield until this yield into a case
|
||||||
stmtsSinceLastCogo := funcDecl.Body.List[lastCaseEndBodyListIndex+1 : i]
|
stmtsSinceLastCogo := funcDecl.Body.List[lastCaseEndBodyListIndex+1 : i]
|
||||||
@ -319,16 +514,16 @@ func (p *processor) genCogoFuncsNodeProcessor(c *astutil.Cursor) bool {
|
|||||||
&ast.ReturnStmt{},
|
&ast.ReturnStmt{},
|
||||||
)
|
)
|
||||||
|
|
||||||
switchStmt.Body.List = append(switchStmt.Body.List,
|
mainSwitchStmt.Body.List = append(mainSwitchStmt.Body.List,
|
||||||
getCaseWithStmts(
|
getCaseWithStmts(
|
||||||
[]ast.Expr{ast.NewIdent(fmt.Sprint(len(switchStmt.Body.List)))},
|
[]ast.Expr{ast.NewIdent(fmt.Sprint(len(mainSwitchStmt.Body.List)))},
|
||||||
caseStmts,
|
caseStmts,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
lastCaseEndBodyListIndex = i
|
lastCaseEndBodyListIndex = i
|
||||||
|
|
||||||
} else if cogoFuncSelExpr.Sel.Name == "YieldTo" {
|
} else if selExpr.Sel.Name == "YieldTo" {
|
||||||
|
|
||||||
// Add everything from the last begin/yield until this yield into a case
|
// Add everything from the last begin/yield until this yield into a case
|
||||||
stmtsSinceLastCogo := funcDecl.Body.List[lastCaseEndBodyListIndex+1 : i]
|
stmtsSinceLastCogo := funcDecl.Body.List[lastCaseEndBodyListIndex+1 : i]
|
||||||
@ -364,15 +559,15 @@ func (p *processor) genCogoFuncsNodeProcessor(c *astutil.Cursor) bool {
|
|||||||
&ast.ReturnStmt{},
|
&ast.ReturnStmt{},
|
||||||
)
|
)
|
||||||
|
|
||||||
switchStmt.Body.List = append(switchStmt.Body.List,
|
mainSwitchStmt.Body.List = append(mainSwitchStmt.Body.List,
|
||||||
getCaseWithStmts(
|
getCaseWithStmts(
|
||||||
[]ast.Expr{ast.NewIdent(fmt.Sprint(len(switchStmt.Body.List)))},
|
[]ast.Expr{ast.NewIdent(fmt.Sprint(len(mainSwitchStmt.Body.List)))},
|
||||||
caseStmts,
|
caseStmts,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
lastCaseEndBodyListIndex = i
|
lastCaseEndBodyListIndex = i
|
||||||
} else if cogoFuncSelExpr.Sel.Name == "YieldNone" {
|
} else if selExpr.Sel.Name == "YieldNone" {
|
||||||
|
|
||||||
// Add everything from the last begin/yield until this yield into a case
|
// Add everything from the last begin/yield until this yield into a case
|
||||||
stmtsSinceLastCogo := funcDecl.Body.List[lastCaseEndBodyListIndex+1 : i]
|
stmtsSinceLastCogo := funcDecl.Body.List[lastCaseEndBodyListIndex+1 : i]
|
||||||
@ -403,9 +598,9 @@ func (p *processor) genCogoFuncsNodeProcessor(c *astutil.Cursor) bool {
|
|||||||
&ast.ReturnStmt{},
|
&ast.ReturnStmt{},
|
||||||
)
|
)
|
||||||
|
|
||||||
switchStmt.Body.List = append(switchStmt.Body.List,
|
mainSwitchStmt.Body.List = append(mainSwitchStmt.Body.List,
|
||||||
getCaseWithStmts(
|
getCaseWithStmts(
|
||||||
[]ast.Expr{ast.NewIdent(fmt.Sprint(len(switchStmt.Body.List)))},
|
[]ast.Expr{ast.NewIdent(fmt.Sprint(len(mainSwitchStmt.Body.List)))},
|
||||||
caseStmts,
|
caseStmts,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
@ -443,9 +638,9 @@ func (p *processor) genCogoFuncsNodeProcessor(c *astutil.Cursor) bool {
|
|||||||
)
|
)
|
||||||
caseStmts = append(caseStmts, stmtsToEndOfFunc...)
|
caseStmts = append(caseStmts, stmtsToEndOfFunc...)
|
||||||
|
|
||||||
switchStmt.Body.List = append(switchStmt.Body.List,
|
mainSwitchStmt.Body.List = append(mainSwitchStmt.Body.List,
|
||||||
getCaseWithStmts(
|
getCaseWithStmts(
|
||||||
[]ast.Expr{ast.NewIdent(fmt.Sprint(len(switchStmt.Body.List)))},
|
[]ast.Expr{ast.NewIdent(fmt.Sprint(len(mainSwitchStmt.Body.List)))},
|
||||||
caseStmts,
|
caseStmts,
|
||||||
),
|
),
|
||||||
|
|
||||||
@ -471,7 +666,7 @@ func (p *processor) genCogoFuncsNodeProcessor(c *astutil.Cursor) bool {
|
|||||||
// Apply changes
|
// Apply changes
|
||||||
funcDecl.Body.List = funcDecl.Body.List[:beginBodyListIndex]
|
funcDecl.Body.List = funcDecl.Body.List[:beginBodyListIndex]
|
||||||
funcDecl.Body.List = append(funcDecl.Body.List,
|
funcDecl.Body.List = append(funcDecl.Body.List,
|
||||||
switchStmt,
|
mainSwitchStmt,
|
||||||
)
|
)
|
||||||
|
|
||||||
originalList := funcDecl.Body.List
|
originalList := funcDecl.Body.List
|
||||||
@ -586,8 +781,8 @@ func (p *processor) genCogoBlockStmt(blockStmt *ast.BlockStmt, coroutineParamNam
|
|||||||
return subStateNums
|
return subStateNums
|
||||||
}
|
}
|
||||||
|
|
||||||
func getLblNameFromSubStateNum(subStateNum int32) string {
|
func getLblNameFromSubStateNum(switchNum, caseNum int32) string {
|
||||||
return fmt.Sprint("cogo_", subStateNum)
|
return fmt.Sprintf("cogo_%d_%d", switchNum, caseNum)
|
||||||
}
|
}
|
||||||
|
|
||||||
func insertIntoArr[T any](a []T, index int, value T) []T {
|
func insertIntoArr[T any](a []T, index int, value T) []T {
|
||||||
@ -639,7 +834,7 @@ func (p *processor) genHasGenChecksOnOriginalFuncsNodeProcessor(c *astutil.Curso
|
|||||||
}
|
}
|
||||||
|
|
||||||
coroutineParamName := getCoroutineParamNameFromFuncDecl(funcDecl)
|
coroutineParamName := getCoroutineParamNameFromFuncDecl(funcDecl)
|
||||||
if coroutineParamName == "" || !funcHasSelInBody(funcDecl, coroutineParamName, "Begin") {
|
if coroutineParamName == "" || !blockHasOneOrMoreSels(funcDecl.Body, []SelExprInfo{{coroutineParamName, "Begin"}}, false) {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -707,13 +902,33 @@ func createStmtFromSelFuncCall(lhs, rhs string) ast.Stmt {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func funcHasSelInBody(fd *ast.FuncDecl, selLhs, selRhs string) bool {
|
type SelExprInfo struct {
|
||||||
|
// Give `cogo.Yield()`, Lhs would be 'cogo'
|
||||||
|
Lhs string
|
||||||
|
// Give `cogo.Yield()`, Rhs would be 'Yield'
|
||||||
|
Rhs string
|
||||||
|
}
|
||||||
|
|
||||||
if fd.Body == nil || len(fd.Body.List) == 0 {
|
func blockUsesCogo(block *ast.BlockStmt, coroutineParamName string, checkChildBlocks bool) bool {
|
||||||
|
return blockHasOneOrMoreSels(block, []SelExprInfo{{coroutineParamName, "Yield"}, {coroutineParamName, "YieldTo"}, {coroutineParamName, "YieldNone"}}, checkChildBlocks)
|
||||||
|
}
|
||||||
|
|
||||||
|
func blockHasOneOrMoreSels(block *ast.BlockStmt, sels []SelExprInfo, checkChildBlocks bool) bool {
|
||||||
|
|
||||||
|
if block == nil || len(block.List) == 0 {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, stmt := range fd.Body.List {
|
for _, stmt := range block.List {
|
||||||
|
|
||||||
|
// Check recursively if requested
|
||||||
|
if checkChildBlocks {
|
||||||
|
if blockStmt, ok := stmt.(*ast.BlockStmt); ok {
|
||||||
|
if blockHasOneOrMoreSels(blockStmt, sels, true) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Find functions calls in the style of 'cogo.ABC123()'
|
// Find functions calls in the style of 'cogo.ABC123()'
|
||||||
exprStmt, ok := stmt.(*ast.ExprStmt)
|
exprStmt, ok := stmt.(*ast.ExprStmt)
|
||||||
@ -731,10 +946,12 @@ func funcHasSelInBody(fd *ast.FuncDecl, selLhs, selRhs string) bool {
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
if selExprIs(selExpr, selLhs, selRhs) {
|
for _, v := range sels {
|
||||||
|
if selExprIs(selExpr, v.Lhs, v.Rhs) {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
@ -789,10 +1006,10 @@ func selExprIs(selExpr *ast.SelectorExpr, lhs, rhs string) bool {
|
|||||||
return pkgIdentExpr.Name == lhs && selExpr.Sel.Name == rhs
|
return pkgIdentExpr.Name == lhs && selExpr.Sel.Name == rhs
|
||||||
}
|
}
|
||||||
|
|
||||||
func getCaseWithStmts(caseConditions []ast.Expr, stmts []ast.Stmt) *ast.CaseClause {
|
func getCaseWithStmts(caseConditions []ast.Expr, bodyStmts []ast.Stmt) *ast.CaseClause {
|
||||||
return &ast.CaseClause{
|
return &ast.CaseClause{
|
||||||
List: caseConditions,
|
List: caseConditions,
|
||||||
Body: stmts,
|
Body: bodyStmts,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user