From ea05c65c93a5b48bc8f2b7f46a4b95ac618f35e0 Mon Sep 17 00:00:00 2001 From: bloeys Date: Thu, 27 Oct 2022 03:33:54 +0400 Subject: [PATCH] Maybe valid code? --- cogo/cogo.go | 26 ++++++-- inliner/main.go | 154 ++++++++++++++++++++++++++++++++++++++---------- main.go | 54 +++-------------- 3 files changed, 153 insertions(+), 81 deletions(-) diff --git a/cogo/cogo.go b/cogo/cogo.go index 7c9c459..8d1b6cf 100755 --- a/cogo/cogo.go +++ b/cogo/cogo.go @@ -1,14 +1,28 @@ package cogo -// func Tick(c any) { -// } +type CoroutineFunc[InT, OutT any] func(c *Coroutine[InT, OutT]) (out OutT) -func Begin() { +type Coroutine[InT, OutT any] struct { + State int32 + In InT + Func CoroutineFunc[InT, OutT] } -func Yield[T any](out T) { - +func (c *Coroutine[InT, OutT]) Begin() { } -func End() { +func (c *Coroutine[InT, OutT]) Tick() (out OutT, done bool) { + + if c.State == -1 { + return out, true + } + + out = c.Func(c) + return out, c.State == -1 +} + +func (c *Coroutine[InT, OutT]) Yield(out OutT) { +} + +func (c *Coroutine[InT, OutT]) End() { } diff --git a/inliner/main.go b/inliner/main.go index fa15e08..f33beb2 100755 --- a/inliner/main.go +++ b/inliner/main.go @@ -11,6 +11,8 @@ import ( "golang.org/x/tools/go/packages" ) +const cogoSwitchLbl = "cogoSwitchLbl" + func printDebugInfo() { fmt.Printf("Running inliner on '%s'\n", os.Getenv("GOFILE")) @@ -85,14 +87,15 @@ func (p *processor) processDeclNode(c *astutil.Cursor) bool { return true } - if !funcDeclCallsCogo(funcDecl) { + coroutineParamName := getCoroutineParamNameFromFuncDecl(funcDecl) + if coroutineParamName == "" || !funcDeclCallsCoroutineBegin(funcDecl, coroutineParamName) { return false } beginBodyListIndex := -1 lastCaseEndBodyListIndex := -1 switchStmt := &ast.SwitchStmt{ - Tag: ast.NewIdent("c.state"), + Tag: ast.NewIdent(coroutineParamName + ".State"), Body: &ast.BlockStmt{ List: []ast.Stmt{}, }, @@ -100,7 +103,7 @@ func (p *processor) processDeclNode(c *astutil.Cursor) bool { for i, stmt := range funcDecl.Body.List { - var cogoFuncCallExpr *ast.SelectorExpr + var cogoFuncSelExpr *ast.SelectorExpr // ifStmt, ifStmtOk := stmt.(*ast.IfStmt) // if ifStmtOk { @@ -119,44 +122,89 @@ func (p *processor) processDeclNode(c *astutil.Cursor) bool { continue } - cogoFuncCallExpr, exprStmtOk = callExpr.Fun.(*ast.SelectorExpr) + cogoFuncSelExpr, exprStmtOk = callExpr.Fun.(*ast.SelectorExpr) if !exprStmtOk { continue } - if !funcCallHasPkgName(cogoFuncCallExpr, "cogo") { + if !funcCallHasLhsName(cogoFuncSelExpr, coroutineParamName) { continue } - cogoFuncCallLineNum := p.fset.File(cogoFuncCallExpr.Pos()).Line(cogoFuncCallExpr.Pos()) - fmt.Printf("Found: '%+v' at line %d\n", cogoFuncCallExpr, cogoFuncCallLineNum) + // cogoFuncCallLineNum := p.fset.File(cogoFuncSelExpr.Pos()).Line(cogoFuncSelExpr.Pos()) // Now that we found a call to cogo decide what to do - if cogoFuncCallExpr.Sel.Name == "Begin" { + if cogoFuncSelExpr.Sel.Name == "Begin" { beginBodyListIndex = i lastCaseEndBodyListIndex = i continue - } else if cogoFuncCallExpr.Sel.Name == "Yield" || cogoFuncCallExpr.Sel.Name == "End" { + } else if cogoFuncSelExpr.Sel.Name == "Yield" { // Add everything from the last begin/yield until this yield into a case - stmtsSinceLastCogo := funcDecl.Body.List[lastCaseEndBodyListIndex+1 : i] - switchStmt.Body.List = append(switchStmt.Body.List, getCaseWithStmts( - stmtsSinceLastCogo, - []ast.Expr{ast.NewIdent(fmt.Sprint(cogoFuncCallLineNum))}, - )) + + caseStmts := make([]ast.Stmt, 0, len(stmtsSinceLastCogo)+2) + caseStmts = append(caseStmts, stmtsSinceLastCogo...) + caseStmts = append(caseStmts, + &ast.IncDecStmt{ + Tok: token.INC, + X: ast.NewIdent(coroutineParamName + ".State"), + }, + &ast.ReturnStmt{ + Results: callExpr.Args, + }, + ) + + switchStmt.Body.List = append(switchStmt.Body.List, + getCaseWithStmts( + []ast.Expr{ast.NewIdent(fmt.Sprint(len(switchStmt.Body.List)))}, + caseStmts, + ), + ) + + lastCaseEndBodyListIndex = i + + } else if cogoFuncSelExpr.Sel.Name == "End" { + + // Add everything from the last begin/yield until this yield into a case + stmtsSinceLastCogo := funcDecl.Body.List[lastCaseEndBodyListIndex+1 : i] + stmtsAfterEnd := funcDecl.Body.List[i+1:] + + caseStmts := make([]ast.Stmt, 0, len(stmtsSinceLastCogo)+len(stmtsAfterEnd)+1) + caseStmts = append(caseStmts, stmtsSinceLastCogo...) + caseStmts = append(caseStmts, + &ast.AssignStmt{ + Lhs: []ast.Expr{ast.NewIdent(coroutineParamName + ".State")}, + Tok: token.ASSIGN, + Rhs: []ast.Expr{ast.NewIdent("-1")}, + }, + ) + caseStmts = append(caseStmts, stmtsAfterEnd...) + + switchStmt.Body.List = append(switchStmt.Body.List, + getCaseWithStmts( + []ast.Expr{ast.NewIdent(fmt.Sprint(len(switchStmt.Body.List)))}, + caseStmts, + ), + ) lastCaseEndBodyListIndex = i } } funcDecl.Body.List = funcDecl.Body.List[:beginBodyListIndex] - funcDecl.Body.List = append(funcDecl.Body.List, switchStmt) + funcDecl.Body.List = append(funcDecl.Body.List, + // &ast.LabeledStmt{ + // Label: ast.NewIdent(cogoSwitchLbl), + // Stmt: switchStmt, + // }, + switchStmt, + ) return true } -func funcDeclCallsCogo(fd *ast.FuncDecl) bool { +func funcDeclCallsCoroutineBegin(fd *ast.FuncDecl, coroutineParamName string) bool { if fd.Body == nil || len(fd.Body.List) == 0 { return false @@ -180,35 +228,81 @@ func funcDeclCallsCogo(fd *ast.FuncDecl) bool { continue } - return funcCallHasPkgName(pkgFuncCallExpr, "cogo") + if funcCallHasLhsName(pkgFuncCallExpr, coroutineParamName) { + return true + } } return false } -func funcCallHasPkgName(selExpr *ast.SelectorExpr, pkgName string) bool { +func getCoroutineParamNameFromFuncDecl(fd *ast.FuncDecl) string { + + for _, p := range fd.Type.Params.List { + + ptrExpr, ok := p.Type.(*ast.StarExpr) + if !ok { + continue + } + + // indexList because coroutine type takes multiple generic parameters, creating an indexed list + indexListExpr, ok := ptrExpr.X.(*ast.IndexListExpr) + if !ok { + continue + } + + selExpr, ok := indexListExpr.X.(*ast.SelectorExpr) + if !ok { + continue + } + + if !selExprIs(selExpr, "cogo", "Coroutine") { + continue + } + + return p.Names[0].Name + } + + return "" +} + +// func getIdentNameFromExprOrPanic(e ast.Expr) string { +// return e.(*ast.Ident).Name +// } + +func funcCallHasLhsName(selExpr *ast.SelectorExpr, pkgName string) bool { pkgIdent, ok := selExpr.X.(*ast.Ident) return ok && pkgIdent.Name == pkgName } -func filter[T any](arr []T, where func(x T) bool) []T { +func selExprIs(selExpr *ast.SelectorExpr, pkgName, typeName string) bool { - out := []T{} - for i := 0; i < len(arr); i++ { - - if !where(arr[i]) { - continue - } - - out = append(out, arr[i]) + pkgIdentExpr, ok := selExpr.X.(*ast.Ident) + if !ok { + return false } - return out + return pkgIdentExpr.Name == pkgName && selExpr.Sel.Name == typeName } -func getCaseWithStmts(stmts []ast.Stmt, conditions []ast.Expr) *ast.CaseClause { +// func filter[T any](arr []T, where func(x T) bool) []T { + +// out := []T{} +// for i := 0; i < len(arr); i++ { + +// if !where(arr[i]) { +// continue +// } + +// out = append(out, arr[i]) +// } + +// return out +// } + +func getCaseWithStmts(caseConditions []ast.Expr, stmts []ast.Stmt) *ast.CaseClause { return &ast.CaseClause{ - List: conditions, + List: caseConditions, Body: stmts, } } diff --git a/main.go b/main.go index d118a61..61bfcf4 100755 --- a/main.go +++ b/main.go @@ -9,51 +9,28 @@ import ( "github.com/bloeys/cogo/cogo" ) -type CoroutineFunc[InT, OutT any] func(c *Coroutine[InT, OutT]) (out OutT) - -type Coroutine[InT, OutT any] struct { - State int32 - In InT - Func CoroutineFunc[InT, OutT] -} - -func (c *Coroutine[InT, OutT]) Tick() (out OutT, done bool) { - - if c.State == -1 { - return out, true - } - - out = c.Func(c) - return out, c.State == -1 -} - -// func (c *Coroutine[InT, OutT]) Yield(out OutT) { -// } - -func (c *Coroutine[InT, OutT]) Break() { -} - func Wow() { println("wow") } -func test(c *Coroutine[int, int]) (out int) { +func test(c *cogo.Coroutine[int, int]) (out int) { - cogo.Begin() + c.Begin() println("Tick 1") - cogo.Yield(1) + c.Yield(1) println("Tick 2") - cogo.Yield(2) + c.Yield(2) println("Tick 3") - cogo.Yield(3) + c.Yield(3) println("Tick 4") - cogo.Yield(4) + c.Yield(4) - cogo.End() + println("Tick before end") + c.End() // switch c.State { // case 0: @@ -97,20 +74,7 @@ func test(c *Coroutine[int, int]) (out int) { func main() { - x := 1 -switch_start: - switch x { - case 1: - println(1) - x = 3 - goto switch_start - case 2: - println(2) - case 3: - println(3) - } - return - c := &Coroutine[int, int]{ + c := &cogo.Coroutine[int, int]{ Func: test, In: 0, }