Maybe valid code?

This commit is contained in:
bloeys
2022-10-27 03:33:54 +04:00
parent 80ce751286
commit ea05c65c93
3 changed files with 153 additions and 81 deletions

View File

@ -1,14 +1,28 @@
package cogo package cogo
// func Tick(c any) { type CoroutineFunc[InT, OutT any] func(c *Coroutine[InT, OutT]) (out OutT)
// }
func Begin() { type Coroutine[InT, OutT any] struct {
State int32
In InT
Func CoroutineFunc[InT, OutT]
} }
func Yield[T any](out T) { func (c *Coroutine[InT, OutT]) Begin() {
} }
func End() { func (c *Coroutine[InT, OutT]) Tick() (out OutT, done bool) {
if c.State == -1 {
return out, true
}
out = c.Func(c)
return out, c.State == -1
}
func (c *Coroutine[InT, OutT]) Yield(out OutT) {
}
func (c *Coroutine[InT, OutT]) End() {
} }

View File

@ -11,6 +11,8 @@ import (
"golang.org/x/tools/go/packages" "golang.org/x/tools/go/packages"
) )
const cogoSwitchLbl = "cogoSwitchLbl"
func printDebugInfo() { func printDebugInfo() {
fmt.Printf("Running inliner on '%s'\n", os.Getenv("GOFILE")) fmt.Printf("Running inliner on '%s'\n", os.Getenv("GOFILE"))
@ -85,14 +87,15 @@ func (p *processor) processDeclNode(c *astutil.Cursor) bool {
return true return true
} }
if !funcDeclCallsCogo(funcDecl) { coroutineParamName := getCoroutineParamNameFromFuncDecl(funcDecl)
if coroutineParamName == "" || !funcDeclCallsCoroutineBegin(funcDecl, coroutineParamName) {
return false return false
} }
beginBodyListIndex := -1 beginBodyListIndex := -1
lastCaseEndBodyListIndex := -1 lastCaseEndBodyListIndex := -1
switchStmt := &ast.SwitchStmt{ switchStmt := &ast.SwitchStmt{
Tag: ast.NewIdent("c.state"), Tag: ast.NewIdent(coroutineParamName + ".State"),
Body: &ast.BlockStmt{ Body: &ast.BlockStmt{
List: []ast.Stmt{}, List: []ast.Stmt{},
}, },
@ -100,7 +103,7 @@ func (p *processor) processDeclNode(c *astutil.Cursor) bool {
for i, stmt := range funcDecl.Body.List { for i, stmt := range funcDecl.Body.List {
var cogoFuncCallExpr *ast.SelectorExpr var cogoFuncSelExpr *ast.SelectorExpr
// ifStmt, ifStmtOk := stmt.(*ast.IfStmt) // ifStmt, ifStmtOk := stmt.(*ast.IfStmt)
// if ifStmtOk { // if ifStmtOk {
@ -119,44 +122,89 @@ func (p *processor) processDeclNode(c *astutil.Cursor) bool {
continue continue
} }
cogoFuncCallExpr, exprStmtOk = callExpr.Fun.(*ast.SelectorExpr) cogoFuncSelExpr, exprStmtOk = callExpr.Fun.(*ast.SelectorExpr)
if !exprStmtOk { if !exprStmtOk {
continue continue
} }
if !funcCallHasPkgName(cogoFuncCallExpr, "cogo") { if !funcCallHasLhsName(cogoFuncSelExpr, coroutineParamName) {
continue continue
} }
cogoFuncCallLineNum := p.fset.File(cogoFuncCallExpr.Pos()).Line(cogoFuncCallExpr.Pos()) // cogoFuncCallLineNum := p.fset.File(cogoFuncSelExpr.Pos()).Line(cogoFuncSelExpr.Pos())
fmt.Printf("Found: '%+v' at line %d\n", cogoFuncCallExpr, cogoFuncCallLineNum)
// Now that we found a call to cogo decide what to do // Now that we found a call to cogo decide what to do
if cogoFuncCallExpr.Sel.Name == "Begin" { if cogoFuncSelExpr.Sel.Name == "Begin" {
beginBodyListIndex = i beginBodyListIndex = i
lastCaseEndBodyListIndex = i lastCaseEndBodyListIndex = i
continue continue
} else if cogoFuncCallExpr.Sel.Name == "Yield" || cogoFuncCallExpr.Sel.Name == "End" { } else if cogoFuncSelExpr.Sel.Name == "Yield" {
// Add everything from the last begin/yield until this yield into a case // Add everything from the last begin/yield until this yield into a case
stmtsSinceLastCogo := funcDecl.Body.List[lastCaseEndBodyListIndex+1 : i] stmtsSinceLastCogo := funcDecl.Body.List[lastCaseEndBodyListIndex+1 : i]
switchStmt.Body.List = append(switchStmt.Body.List, getCaseWithStmts(
stmtsSinceLastCogo, caseStmts := make([]ast.Stmt, 0, len(stmtsSinceLastCogo)+2)
[]ast.Expr{ast.NewIdent(fmt.Sprint(cogoFuncCallLineNum))}, caseStmts = append(caseStmts, stmtsSinceLastCogo...)
)) caseStmts = append(caseStmts,
&ast.IncDecStmt{
Tok: token.INC,
X: ast.NewIdent(coroutineParamName + ".State"),
},
&ast.ReturnStmt{
Results: callExpr.Args,
},
)
switchStmt.Body.List = append(switchStmt.Body.List,
getCaseWithStmts(
[]ast.Expr{ast.NewIdent(fmt.Sprint(len(switchStmt.Body.List)))},
caseStmts,
),
)
lastCaseEndBodyListIndex = i
} else if cogoFuncSelExpr.Sel.Name == "End" {
// Add everything from the last begin/yield until this yield into a case
stmtsSinceLastCogo := funcDecl.Body.List[lastCaseEndBodyListIndex+1 : i]
stmtsAfterEnd := funcDecl.Body.List[i+1:]
caseStmts := make([]ast.Stmt, 0, len(stmtsSinceLastCogo)+len(stmtsAfterEnd)+1)
caseStmts = append(caseStmts, stmtsSinceLastCogo...)
caseStmts = append(caseStmts,
&ast.AssignStmt{
Lhs: []ast.Expr{ast.NewIdent(coroutineParamName + ".State")},
Tok: token.ASSIGN,
Rhs: []ast.Expr{ast.NewIdent("-1")},
},
)
caseStmts = append(caseStmts, stmtsAfterEnd...)
switchStmt.Body.List = append(switchStmt.Body.List,
getCaseWithStmts(
[]ast.Expr{ast.NewIdent(fmt.Sprint(len(switchStmt.Body.List)))},
caseStmts,
),
)
lastCaseEndBodyListIndex = i lastCaseEndBodyListIndex = i
} }
} }
funcDecl.Body.List = funcDecl.Body.List[:beginBodyListIndex] funcDecl.Body.List = funcDecl.Body.List[:beginBodyListIndex]
funcDecl.Body.List = append(funcDecl.Body.List, switchStmt) funcDecl.Body.List = append(funcDecl.Body.List,
// &ast.LabeledStmt{
// Label: ast.NewIdent(cogoSwitchLbl),
// Stmt: switchStmt,
// },
switchStmt,
)
return true return true
} }
func funcDeclCallsCogo(fd *ast.FuncDecl) bool { func funcDeclCallsCoroutineBegin(fd *ast.FuncDecl, coroutineParamName string) bool {
if fd.Body == nil || len(fd.Body.List) == 0 { if fd.Body == nil || len(fd.Body.List) == 0 {
return false return false
@ -180,35 +228,81 @@ func funcDeclCallsCogo(fd *ast.FuncDecl) bool {
continue continue
} }
return funcCallHasPkgName(pkgFuncCallExpr, "cogo") if funcCallHasLhsName(pkgFuncCallExpr, coroutineParamName) {
return true
}
} }
return false return false
} }
func funcCallHasPkgName(selExpr *ast.SelectorExpr, pkgName string) bool { func getCoroutineParamNameFromFuncDecl(fd *ast.FuncDecl) string {
for _, p := range fd.Type.Params.List {
ptrExpr, ok := p.Type.(*ast.StarExpr)
if !ok {
continue
}
// indexList because coroutine type takes multiple generic parameters, creating an indexed list
indexListExpr, ok := ptrExpr.X.(*ast.IndexListExpr)
if !ok {
continue
}
selExpr, ok := indexListExpr.X.(*ast.SelectorExpr)
if !ok {
continue
}
if !selExprIs(selExpr, "cogo", "Coroutine") {
continue
}
return p.Names[0].Name
}
return ""
}
// func getIdentNameFromExprOrPanic(e ast.Expr) string {
// return e.(*ast.Ident).Name
// }
func funcCallHasLhsName(selExpr *ast.SelectorExpr, pkgName string) bool {
pkgIdent, ok := selExpr.X.(*ast.Ident) pkgIdent, ok := selExpr.X.(*ast.Ident)
return ok && pkgIdent.Name == pkgName return ok && pkgIdent.Name == pkgName
} }
func filter[T any](arr []T, where func(x T) bool) []T { func selExprIs(selExpr *ast.SelectorExpr, pkgName, typeName string) bool {
out := []T{} pkgIdentExpr, ok := selExpr.X.(*ast.Ident)
for i := 0; i < len(arr); i++ { if !ok {
return false
if !where(arr[i]) {
continue
} }
out = append(out, arr[i]) return pkgIdentExpr.Name == pkgName && selExpr.Sel.Name == typeName
} }
return out // func filter[T any](arr []T, where func(x T) bool) []T {
}
func getCaseWithStmts(stmts []ast.Stmt, conditions []ast.Expr) *ast.CaseClause { // out := []T{}
// for i := 0; i < len(arr); i++ {
// if !where(arr[i]) {
// continue
// }
// out = append(out, arr[i])
// }
// return out
// }
func getCaseWithStmts(caseConditions []ast.Expr, stmts []ast.Stmt) *ast.CaseClause {
return &ast.CaseClause{ return &ast.CaseClause{
List: conditions, List: caseConditions,
Body: stmts, Body: stmts,
} }
} }

54
main.go
View File

@ -9,51 +9,28 @@ import (
"github.com/bloeys/cogo/cogo" "github.com/bloeys/cogo/cogo"
) )
type CoroutineFunc[InT, OutT any] func(c *Coroutine[InT, OutT]) (out OutT)
type Coroutine[InT, OutT any] struct {
State int32
In InT
Func CoroutineFunc[InT, OutT]
}
func (c *Coroutine[InT, OutT]) Tick() (out OutT, done bool) {
if c.State == -1 {
return out, true
}
out = c.Func(c)
return out, c.State == -1
}
// func (c *Coroutine[InT, OutT]) Yield(out OutT) {
// }
func (c *Coroutine[InT, OutT]) Break() {
}
func Wow() { func Wow() {
println("wow") println("wow")
} }
func test(c *Coroutine[int, int]) (out int) { func test(c *cogo.Coroutine[int, int]) (out int) {
cogo.Begin() c.Begin()
println("Tick 1") println("Tick 1")
cogo.Yield(1) c.Yield(1)
println("Tick 2") println("Tick 2")
cogo.Yield(2) c.Yield(2)
println("Tick 3") println("Tick 3")
cogo.Yield(3) c.Yield(3)
println("Tick 4") println("Tick 4")
cogo.Yield(4) c.Yield(4)
cogo.End() println("Tick before end")
c.End()
// switch c.State { // switch c.State {
// case 0: // case 0:
@ -97,20 +74,7 @@ func test(c *Coroutine[int, int]) (out int) {
func main() { func main() {
x := 1 c := &cogo.Coroutine[int, int]{
switch_start:
switch x {
case 1:
println(1)
x = 3
goto switch_start
case 2:
println(2)
case 3:
println(3)
}
return
c := &Coroutine[int, int]{
Func: test, Func: test,
In: 0, In: 0,
} }