mirror of
https://github.com/bloeys/cogo.git
synced 2025-12-29 08:58:19 +00:00
Support one level nesting of yield in ifStmts
This commit is contained in:
@ -1,9 +1,12 @@
|
||||
package cogo
|
||||
|
||||
import "fmt"
|
||||
|
||||
type CoroutineFunc[InT, OutT any] func(c *Coroutine[InT, OutT]) (out OutT)
|
||||
|
||||
type Coroutine[InT, OutT any] struct {
|
||||
State int32
|
||||
SubState int32
|
||||
In InT
|
||||
Func CoroutineFunc[InT, OutT]
|
||||
}
|
||||
@ -22,6 +25,7 @@ func (c *Coroutine[InT, OutT]) Tick() (out OutT, done bool) {
|
||||
}
|
||||
|
||||
func (c *Coroutine[InT, OutT]) Yield(out OutT) {
|
||||
panic(fmt.Sprintf("Yield got called at runtime, which means the code generator was not run, you used cogo incorrectly, or cogo has a bug. Yield should NOT get called at runtime. coroutine: %+v;;; yield value: %+v;;;", c, out))
|
||||
}
|
||||
|
||||
func HasGen() bool {
|
||||
|
||||
158
inliner/main.go
158
inliner/main.go
@ -6,6 +6,7 @@ import (
|
||||
"go/format"
|
||||
"go/token"
|
||||
"io"
|
||||
"math/rand"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
@ -166,16 +167,52 @@ func (p *processor) genCogoFuncsNodeProcessor(c *astutil.Cursor) bool {
|
||||
},
|
||||
}
|
||||
|
||||
subSwitchStmt := &ast.SwitchStmt{
|
||||
Tag: ast.NewIdent(coroutineParamName + ".SubState"),
|
||||
Body: &ast.BlockStmt{
|
||||
List: []ast.Stmt{
|
||||
getCaseWithStmts(nil, []ast.Stmt{}),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for i, stmt := range funcDecl.Body.List {
|
||||
|
||||
var cogoFuncSelExpr *ast.SelectorExpr
|
||||
|
||||
ifStmt, ok := stmt.(*ast.IfStmt)
|
||||
if ok && ifStmtIsHasGen(ifStmt) {
|
||||
ifStmt, ifStmtOk := stmt.(*ast.IfStmt)
|
||||
if ifStmtOk {
|
||||
|
||||
if ifStmtIsHasGen(ifStmt) {
|
||||
funcDecl.Body.List[i] = &ast.EmptyStmt{}
|
||||
continue
|
||||
}
|
||||
|
||||
subStateNums := p.genCogoIfStmt(ifStmt, coroutineParamName, len(switchStmt.Body.List))
|
||||
for _, subStateNum := range subStateNums {
|
||||
|
||||
subSwitchStmt.Body.List = append(subSwitchStmt.Body.List,
|
||||
getCaseWithStmts(
|
||||
[]ast.Expr{ast.NewIdent(fmt.Sprint(subStateNum))},
|
||||
[]ast.Stmt{
|
||||
&ast.BranchStmt{
|
||||
Tok: token.GOTO,
|
||||
Label: ast.NewIdent(getLblNameFromSubStateNum(subStateNum)),
|
||||
},
|
||||
},
|
||||
),
|
||||
)
|
||||
|
||||
var stmtToInsert ast.Stmt = &ast.LabeledStmt{
|
||||
Label: ast.NewIdent(getLblNameFromSubStateNum(subStateNum)),
|
||||
Stmt: &ast.EmptyStmt{},
|
||||
}
|
||||
funcDecl.Body.List = insertIntoArr(funcDecl.Body.List, i+1, stmtToInsert)
|
||||
}
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
// Find functions calls in the style of 'xyz.ABC123()'
|
||||
exprStmt, exprStmtOk := stmt.(*ast.ExprStmt)
|
||||
if !exprStmtOk {
|
||||
@ -210,12 +247,28 @@ func (p *processor) genCogoFuncsNodeProcessor(c *astutil.Cursor) bool {
|
||||
stmtsSinceLastCogo := funcDecl.Body.List[lastCaseEndBodyListIndex+1 : i]
|
||||
|
||||
caseStmts := make([]ast.Stmt, 0, len(stmtsSinceLastCogo)+2)
|
||||
|
||||
caseStmts = append(caseStmts, subSwitchStmt)
|
||||
subSwitchStmt = &ast.SwitchStmt{
|
||||
Tag: ast.NewIdent(coroutineParamName + ".SubState"),
|
||||
Body: &ast.BlockStmt{
|
||||
List: []ast.Stmt{
|
||||
getCaseWithStmts(nil, []ast.Stmt{}),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
caseStmts = append(caseStmts, stmtsSinceLastCogo...)
|
||||
caseStmts = append(caseStmts,
|
||||
&ast.IncDecStmt{
|
||||
Tok: token.INC,
|
||||
X: ast.NewIdent(coroutineParamName + ".State"),
|
||||
},
|
||||
&ast.AssignStmt{
|
||||
Lhs: []ast.Expr{ast.NewIdent(coroutineParamName + ".SubState")},
|
||||
Tok: token.ASSIGN,
|
||||
Rhs: []ast.Expr{ast.NewIdent("-1")},
|
||||
},
|
||||
&ast.ReturnStmt{
|
||||
Results: callExpr.Args,
|
||||
},
|
||||
@ -237,12 +290,28 @@ func (p *processor) genCogoFuncsNodeProcessor(c *astutil.Cursor) bool {
|
||||
stmtsToEndOfFunc := funcDecl.Body.List[lastCaseEndBodyListIndex+1:]
|
||||
|
||||
caseStmts := make([]ast.Stmt, 0, len(stmtsToEndOfFunc)+1)
|
||||
|
||||
caseStmts = append(caseStmts, subSwitchStmt)
|
||||
subSwitchStmt = &ast.SwitchStmt{
|
||||
Tag: ast.NewIdent(coroutineParamName + ".SubState"),
|
||||
Body: &ast.BlockStmt{
|
||||
List: []ast.Stmt{
|
||||
getCaseWithStmts(nil, []ast.Stmt{}),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
caseStmts = append(caseStmts,
|
||||
&ast.AssignStmt{
|
||||
Lhs: []ast.Expr{ast.NewIdent(coroutineParamName + ".State")},
|
||||
Tok: token.ASSIGN,
|
||||
Rhs: []ast.Expr{ast.NewIdent("-1")},
|
||||
},
|
||||
&ast.AssignStmt{
|
||||
Lhs: []ast.Expr{ast.NewIdent(coroutineParamName + ".SubState")},
|
||||
Tok: token.ASSIGN,
|
||||
Rhs: []ast.Expr{ast.NewIdent("-1")},
|
||||
},
|
||||
)
|
||||
caseStmts = append(caseStmts, stmtsToEndOfFunc...)
|
||||
|
||||
@ -261,6 +330,11 @@ func (p *processor) genCogoFuncsNodeProcessor(c *astutil.Cursor) bool {
|
||||
Tok: token.ASSIGN,
|
||||
Rhs: []ast.Expr{ast.NewIdent("-1")},
|
||||
},
|
||||
&ast.AssignStmt{
|
||||
Lhs: []ast.Expr{ast.NewIdent(coroutineParamName + ".SubState")},
|
||||
Tok: token.ASSIGN,
|
||||
Rhs: []ast.Expr{ast.NewIdent("-1")},
|
||||
},
|
||||
&ast.ReturnStmt{},
|
||||
},
|
||||
),
|
||||
@ -282,6 +356,81 @@ func (p *processor) genCogoFuncsNodeProcessor(c *astutil.Cursor) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (p *processor) genCogoIfStmt(ifStmt *ast.IfStmt, coroutineParamName string, currCase int) (subStateNums []int32) {
|
||||
|
||||
for i, stmt := range ifStmt.Body.List {
|
||||
|
||||
selExpr, selExprArgs := tryGetSelExprFromStmt(stmt, coroutineParamName, "Yield")
|
||||
if selExpr == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// @TODO: Ensure that subStateNums don't get reused
|
||||
// subStateNum >= 1000_000
|
||||
newSubStateNum := rand.Int31() + 1000_000
|
||||
subStateNums = append(subStateNums, newSubStateNum)
|
||||
ifStmt.Body.List[i] = &ast.BlockStmt{
|
||||
List: []ast.Stmt{
|
||||
&ast.AssignStmt{
|
||||
Lhs: []ast.Expr{ast.NewIdent(coroutineParamName + ".State")},
|
||||
Tok: token.ASSIGN,
|
||||
Rhs: []ast.Expr{ast.NewIdent(fmt.Sprint(currCase))},
|
||||
},
|
||||
&ast.AssignStmt{
|
||||
Lhs: []ast.Expr{ast.NewIdent(coroutineParamName + ".SubState")},
|
||||
Tok: token.ASSIGN,
|
||||
Rhs: []ast.Expr{ast.NewIdent(fmt.Sprint(newSubStateNum))},
|
||||
},
|
||||
&ast.ReturnStmt{
|
||||
Results: selExprArgs,
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
return subStateNums
|
||||
}
|
||||
|
||||
func getLblNameFromSubStateNum(subStateNum int32) string {
|
||||
return fmt.Sprint("cogo_", subStateNum)
|
||||
}
|
||||
|
||||
func insertIntoArr[T any](a []T, index int, value T) []T {
|
||||
|
||||
if len(a) == index {
|
||||
return append(a, value)
|
||||
}
|
||||
|
||||
a = append(a[:index+1], a[index:]...)
|
||||
a[index] = value
|
||||
return a
|
||||
}
|
||||
|
||||
func tryGetSelExprFromStmt(stmt ast.Stmt, lhs, rhs string) (selExpr *ast.SelectorExpr, args []ast.Expr) {
|
||||
|
||||
exprStmt, ok := stmt.(*ast.ExprStmt)
|
||||
if !ok {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
callExpr, ok := exprStmt.X.(*ast.CallExpr)
|
||||
if !ok {
|
||||
return nil, nil
|
||||
}
|
||||
args = callExpr.Args
|
||||
|
||||
selExpr, ok = callExpr.Fun.(*ast.SelectorExpr)
|
||||
if !ok {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
if selExprIs(selExpr, lhs, rhs) {
|
||||
return selExpr, args
|
||||
}
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (p *processor) genHasGenChecksOnOriginalFuncsNodeProcessor(c *astutil.Cursor) bool {
|
||||
|
||||
n := c.Node()
|
||||
@ -434,14 +583,14 @@ func funcCallHasLhsName(selExpr *ast.SelectorExpr, pkgName string) bool {
|
||||
return ok && pkgIdent.Name == pkgName
|
||||
}
|
||||
|
||||
func selExprIs(selExpr *ast.SelectorExpr, pkgName, typeName string) bool {
|
||||
func selExprIs(selExpr *ast.SelectorExpr, lhs, rhs string) bool {
|
||||
|
||||
pkgIdentExpr, ok := selExpr.X.(*ast.Ident)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
|
||||
return pkgIdentExpr.Name == pkgName && selExpr.Sel.Name == typeName
|
||||
return pkgIdentExpr.Name == lhs && selExpr.Sel.Name == rhs
|
||||
}
|
||||
|
||||
// func filter[T any](arr []T, where func(x T) bool) []T {
|
||||
@ -485,6 +634,7 @@ func writeAst(fName, topComment string, fset *token.FileSet, node any) {
|
||||
|
||||
b, err := imports.Process(fName, nil, nil)
|
||||
if err != nil {
|
||||
format.Node(os.Stdout, fset, node)
|
||||
panic("Failed to process imports on file " + fName + ". Err: " + err.Error())
|
||||
}
|
||||
|
||||
|
||||
44
main.cogo.go
44
main.cogo.go
@ -6,33 +6,67 @@ import "github.com/bloeys/cogo/cogo"
|
||||
func test_cogo(c *cogo.Coroutine[int, int]) (out int) {
|
||||
switch c.State {
|
||||
case 0:
|
||||
switch c.SubState {
|
||||
default:
|
||||
}
|
||||
|
||||
println("Tick 1")
|
||||
println("\nTick 1")
|
||||
c.State++
|
||||
c.SubState = -1
|
||||
return 1
|
||||
case 1:
|
||||
switch c.SubState {
|
||||
default:
|
||||
case 1299498081:
|
||||
goto cogo_1299498081
|
||||
}
|
||||
|
||||
println("Tick 2")
|
||||
if c.In > 1 {
|
||||
println("\nTick 1.5")
|
||||
{
|
||||
c.State = 1
|
||||
c.SubState = 1299498081
|
||||
return c.In
|
||||
}
|
||||
}
|
||||
cogo_1299498081:
|
||||
;
|
||||
|
||||
println("\nTick 2")
|
||||
c.State++
|
||||
c.SubState = -1
|
||||
return 2
|
||||
case 2:
|
||||
switch c.SubState {
|
||||
default:
|
||||
}
|
||||
|
||||
println("Tick 3")
|
||||
println("\nTick 3")
|
||||
c.State++
|
||||
c.SubState = -1
|
||||
return 3
|
||||
case 3:
|
||||
switch c.SubState {
|
||||
default:
|
||||
}
|
||||
|
||||
println("Tick 4")
|
||||
println("\nTick 4")
|
||||
c.State++
|
||||
c.SubState = -1
|
||||
return 4
|
||||
case 4:
|
||||
switch c.SubState {
|
||||
default:
|
||||
}
|
||||
c.State = -1
|
||||
c.SubState = -1
|
||||
|
||||
println("Tick before end")
|
||||
println("\nTick before end")
|
||||
|
||||
return out
|
||||
default:
|
||||
c.State = -1
|
||||
c.SubState = -1
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
16
main.go
16
main.go
@ -16,19 +16,24 @@ func test(c *cogo.Coroutine[int, int]) (out int) {
|
||||
|
||||
c.Begin()
|
||||
|
||||
println("Tick 1")
|
||||
println("\nTick 1")
|
||||
c.Yield(1)
|
||||
|
||||
println("Tick 2")
|
||||
if c.In > 1 {
|
||||
println("\nTick 1.5")
|
||||
c.Yield(c.In)
|
||||
}
|
||||
|
||||
println("\nTick 2")
|
||||
c.Yield(2)
|
||||
|
||||
println("Tick 3")
|
||||
println("\nTick 3")
|
||||
c.Yield(3)
|
||||
|
||||
println("Tick 4")
|
||||
println("\nTick 4")
|
||||
c.Yield(4)
|
||||
|
||||
println("Tick before end")
|
||||
println("\nTick before end")
|
||||
|
||||
return out
|
||||
}
|
||||
@ -40,6 +45,7 @@ func main() {
|
||||
In: 0,
|
||||
}
|
||||
|
||||
c.In = 5
|
||||
for out, done := c.Tick(); !done; out, done = c.Tick() {
|
||||
println(out)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user