Support one level nesting of yield in ifStmts

This commit is contained in:
bloeys
2022-10-28 01:04:47 +04:00
parent eeea31092b
commit e0cc0c3de4
4 changed files with 212 additions and 18 deletions

View File

@ -1,11 +1,14 @@
package cogo
import "fmt"
type CoroutineFunc[InT, OutT any] func(c *Coroutine[InT, OutT]) (out OutT)
type Coroutine[InT, OutT any] struct {
State int32
In InT
Func CoroutineFunc[InT, OutT]
State int32
SubState int32
In InT
Func CoroutineFunc[InT, OutT]
}
func (c *Coroutine[InT, OutT]) Begin() {
@ -22,6 +25,7 @@ func (c *Coroutine[InT, OutT]) Tick() (out OutT, done bool) {
}
func (c *Coroutine[InT, OutT]) Yield(out OutT) {
panic(fmt.Sprintf("Yield got called at runtime, which means the code generator was not run, you used cogo incorrectly, or cogo has a bug. Yield should NOT get called at runtime. coroutine: %+v;;; yield value: %+v;;;", c, out))
}
func HasGen() bool {

View File

@ -6,6 +6,7 @@ import (
"go/format"
"go/token"
"io"
"math/rand"
"os"
"strings"
@ -166,13 +167,49 @@ func (p *processor) genCogoFuncsNodeProcessor(c *astutil.Cursor) bool {
},
}
subSwitchStmt := &ast.SwitchStmt{
Tag: ast.NewIdent(coroutineParamName + ".SubState"),
Body: &ast.BlockStmt{
List: []ast.Stmt{
getCaseWithStmts(nil, []ast.Stmt{}),
},
},
}
for i, stmt := range funcDecl.Body.List {
var cogoFuncSelExpr *ast.SelectorExpr
ifStmt, ok := stmt.(*ast.IfStmt)
if ok && ifStmtIsHasGen(ifStmt) {
funcDecl.Body.List[i] = &ast.EmptyStmt{}
ifStmt, ifStmtOk := stmt.(*ast.IfStmt)
if ifStmtOk {
if ifStmtIsHasGen(ifStmt) {
funcDecl.Body.List[i] = &ast.EmptyStmt{}
continue
}
subStateNums := p.genCogoIfStmt(ifStmt, coroutineParamName, len(switchStmt.Body.List))
for _, subStateNum := range subStateNums {
subSwitchStmt.Body.List = append(subSwitchStmt.Body.List,
getCaseWithStmts(
[]ast.Expr{ast.NewIdent(fmt.Sprint(subStateNum))},
[]ast.Stmt{
&ast.BranchStmt{
Tok: token.GOTO,
Label: ast.NewIdent(getLblNameFromSubStateNum(subStateNum)),
},
},
),
)
var stmtToInsert ast.Stmt = &ast.LabeledStmt{
Label: ast.NewIdent(getLblNameFromSubStateNum(subStateNum)),
Stmt: &ast.EmptyStmt{},
}
funcDecl.Body.List = insertIntoArr(funcDecl.Body.List, i+1, stmtToInsert)
}
continue
}
@ -210,12 +247,28 @@ func (p *processor) genCogoFuncsNodeProcessor(c *astutil.Cursor) bool {
stmtsSinceLastCogo := funcDecl.Body.List[lastCaseEndBodyListIndex+1 : i]
caseStmts := make([]ast.Stmt, 0, len(stmtsSinceLastCogo)+2)
caseStmts = append(caseStmts, subSwitchStmt)
subSwitchStmt = &ast.SwitchStmt{
Tag: ast.NewIdent(coroutineParamName + ".SubState"),
Body: &ast.BlockStmt{
List: []ast.Stmt{
getCaseWithStmts(nil, []ast.Stmt{}),
},
},
}
caseStmts = append(caseStmts, stmtsSinceLastCogo...)
caseStmts = append(caseStmts,
&ast.IncDecStmt{
Tok: token.INC,
X: ast.NewIdent(coroutineParamName + ".State"),
},
&ast.AssignStmt{
Lhs: []ast.Expr{ast.NewIdent(coroutineParamName + ".SubState")},
Tok: token.ASSIGN,
Rhs: []ast.Expr{ast.NewIdent("-1")},
},
&ast.ReturnStmt{
Results: callExpr.Args,
},
@ -237,12 +290,28 @@ func (p *processor) genCogoFuncsNodeProcessor(c *astutil.Cursor) bool {
stmtsToEndOfFunc := funcDecl.Body.List[lastCaseEndBodyListIndex+1:]
caseStmts := make([]ast.Stmt, 0, len(stmtsToEndOfFunc)+1)
caseStmts = append(caseStmts, subSwitchStmt)
subSwitchStmt = &ast.SwitchStmt{
Tag: ast.NewIdent(coroutineParamName + ".SubState"),
Body: &ast.BlockStmt{
List: []ast.Stmt{
getCaseWithStmts(nil, []ast.Stmt{}),
},
},
}
caseStmts = append(caseStmts,
&ast.AssignStmt{
Lhs: []ast.Expr{ast.NewIdent(coroutineParamName + ".State")},
Tok: token.ASSIGN,
Rhs: []ast.Expr{ast.NewIdent("-1")},
},
&ast.AssignStmt{
Lhs: []ast.Expr{ast.NewIdent(coroutineParamName + ".SubState")},
Tok: token.ASSIGN,
Rhs: []ast.Expr{ast.NewIdent("-1")},
},
)
caseStmts = append(caseStmts, stmtsToEndOfFunc...)
@ -261,6 +330,11 @@ func (p *processor) genCogoFuncsNodeProcessor(c *astutil.Cursor) bool {
Tok: token.ASSIGN,
Rhs: []ast.Expr{ast.NewIdent("-1")},
},
&ast.AssignStmt{
Lhs: []ast.Expr{ast.NewIdent(coroutineParamName + ".SubState")},
Tok: token.ASSIGN,
Rhs: []ast.Expr{ast.NewIdent("-1")},
},
&ast.ReturnStmt{},
},
),
@ -282,6 +356,81 @@ func (p *processor) genCogoFuncsNodeProcessor(c *astutil.Cursor) bool {
return true
}
func (p *processor) genCogoIfStmt(ifStmt *ast.IfStmt, coroutineParamName string, currCase int) (subStateNums []int32) {
for i, stmt := range ifStmt.Body.List {
selExpr, selExprArgs := tryGetSelExprFromStmt(stmt, coroutineParamName, "Yield")
if selExpr == nil {
continue
}
// @TODO: Ensure that subStateNums don't get reused
// subStateNum >= 1000_000
newSubStateNum := rand.Int31() + 1000_000
subStateNums = append(subStateNums, newSubStateNum)
ifStmt.Body.List[i] = &ast.BlockStmt{
List: []ast.Stmt{
&ast.AssignStmt{
Lhs: []ast.Expr{ast.NewIdent(coroutineParamName + ".State")},
Tok: token.ASSIGN,
Rhs: []ast.Expr{ast.NewIdent(fmt.Sprint(currCase))},
},
&ast.AssignStmt{
Lhs: []ast.Expr{ast.NewIdent(coroutineParamName + ".SubState")},
Tok: token.ASSIGN,
Rhs: []ast.Expr{ast.NewIdent(fmt.Sprint(newSubStateNum))},
},
&ast.ReturnStmt{
Results: selExprArgs,
},
},
}
}
return subStateNums
}
func getLblNameFromSubStateNum(subStateNum int32) string {
return fmt.Sprint("cogo_", subStateNum)
}
func insertIntoArr[T any](a []T, index int, value T) []T {
if len(a) == index {
return append(a, value)
}
a = append(a[:index+1], a[index:]...)
a[index] = value
return a
}
func tryGetSelExprFromStmt(stmt ast.Stmt, lhs, rhs string) (selExpr *ast.SelectorExpr, args []ast.Expr) {
exprStmt, ok := stmt.(*ast.ExprStmt)
if !ok {
return nil, nil
}
callExpr, ok := exprStmt.X.(*ast.CallExpr)
if !ok {
return nil, nil
}
args = callExpr.Args
selExpr, ok = callExpr.Fun.(*ast.SelectorExpr)
if !ok {
return nil, nil
}
if selExprIs(selExpr, lhs, rhs) {
return selExpr, args
}
return nil, nil
}
func (p *processor) genHasGenChecksOnOriginalFuncsNodeProcessor(c *astutil.Cursor) bool {
n := c.Node()
@ -434,14 +583,14 @@ func funcCallHasLhsName(selExpr *ast.SelectorExpr, pkgName string) bool {
return ok && pkgIdent.Name == pkgName
}
func selExprIs(selExpr *ast.SelectorExpr, pkgName, typeName string) bool {
func selExprIs(selExpr *ast.SelectorExpr, lhs, rhs string) bool {
pkgIdentExpr, ok := selExpr.X.(*ast.Ident)
if !ok {
return false
}
return pkgIdentExpr.Name == pkgName && selExpr.Sel.Name == typeName
return pkgIdentExpr.Name == lhs && selExpr.Sel.Name == rhs
}
// func filter[T any](arr []T, where func(x T) bool) []T {
@ -485,6 +634,7 @@ func writeAst(fName, topComment string, fset *token.FileSet, node any) {
b, err := imports.Process(fName, nil, nil)
if err != nil {
format.Node(os.Stdout, fset, node)
panic("Failed to process imports on file " + fName + ". Err: " + err.Error())
}

View File

@ -6,33 +6,67 @@ import "github.com/bloeys/cogo/cogo"
func test_cogo(c *cogo.Coroutine[int, int]) (out int) {
switch c.State {
case 0:
switch c.SubState {
default:
}
println("Tick 1")
println("\nTick 1")
c.State++
c.SubState = -1
return 1
case 1:
switch c.SubState {
default:
case 1299498081:
goto cogo_1299498081
}
println("Tick 2")
if c.In > 1 {
println("\nTick 1.5")
{
c.State = 1
c.SubState = 1299498081
return c.In
}
}
cogo_1299498081:
;
println("\nTick 2")
c.State++
c.SubState = -1
return 2
case 2:
switch c.SubState {
default:
}
println("Tick 3")
println("\nTick 3")
c.State++
c.SubState = -1
return 3
case 3:
switch c.SubState {
default:
}
println("Tick 4")
println("\nTick 4")
c.State++
c.SubState = -1
return 4
case 4:
switch c.SubState {
default:
}
c.State = -1
c.SubState = -1
println("Tick before end")
println("\nTick before end")
return out
default:
c.State = -1
c.SubState = -1
return
}
}

16
main.go
View File

@ -16,19 +16,24 @@ func test(c *cogo.Coroutine[int, int]) (out int) {
c.Begin()
println("Tick 1")
println("\nTick 1")
c.Yield(1)
println("Tick 2")
if c.In > 1 {
println("\nTick 1.5")
c.Yield(c.In)
}
println("\nTick 2")
c.Yield(2)
println("Tick 3")
println("\nTick 3")
c.Yield(3)
println("Tick 4")
println("\nTick 4")
c.Yield(4)
println("Tick before end")
println("\nTick before end")
return out
}
@ -40,6 +45,7 @@ func main() {
In: 0,
}
c.In = 5
for out, done := c.Tick(); !done; out, done = c.Tick() {
println(out)
}