From 1d4451dac2ab6065e9ff28c0ed1c0065d32cbb7d Mon Sep 17 00:00:00 2001 From: bloeys Date: Fri, 4 Nov 2022 02:54:10 +0400 Subject: [PATCH] Better err msgs+better detection of coroutines --- main.go | 54 ++++++++++++++++++++++-------------------------------- 1 file changed, 22 insertions(+), 32 deletions(-) diff --git a/main.go b/main.go index 4f1d43a..c936126 100755 --- a/main.go +++ b/main.go @@ -84,10 +84,6 @@ func genHasGenChecksOnOriginalFuncs(cwd string) { panic(err) } - // if len(pkgs) != 1 { - // panic(fmt.Sprintf("expected to find one package but found %d", len(pkgs))) - // } - for _, pkg := range pkgs { p := &processor{ @@ -128,10 +124,20 @@ func (p *processor) genCogoFuncsNodeProcessor(c *astutil.Cursor) bool { } coroutineParamName := getCoroutineParamNameFromFuncDecl(funcDecl) - if coroutineParamName == "" || !funcDeclCallsCoroutineBegin(funcDecl, coroutineParamName) { + if coroutineParamName == "" { return false } + hasBegin := funcHasSelInBody(funcDecl, coroutineParamName, "Begin") + hasYield := funcHasSelInBody(funcDecl, coroutineParamName, "Yield") + if !hasBegin && !hasYield { + return false + } + + if hasYield && !hasBegin { + panic(fmt.Sprintf("Function '%s' in file '%s' has a 'Yield()' call but no 'Begin()'. Please ensure your coroutines have 'Begin()'", funcDecl.Name.Name, p.fset.File(funcDecl.Pos()).Name())) + } + beginBodyListIndex := -1 lastCaseEndBodyListIndex := -1 switchStmt := &ast.SwitchStmt{ @@ -255,12 +261,10 @@ func (p *processor) genCogoFuncsNodeProcessor(c *astutil.Cursor) bool { continue } - if !funcCallHasLhsName(cogoFuncSelExpr, coroutineParamName) { + if !selExprHasLhsName(cogoFuncSelExpr, coroutineParamName) { continue } - // cogoFuncCallLineNum := p.fset.File(cogoFuncSelExpr.Pos()).Line(cogoFuncSelExpr.Pos()) - // Now that we found a call to cogo decide what to do if cogoFuncSelExpr.Sel.Name == "Begin" { @@ -531,7 +535,7 @@ func (p *processor) genHasGenChecksOnOriginalFuncsNodeProcessor(c *astutil.Curso } coroutineParamName := getCoroutineParamNameFromFuncDecl(funcDecl) - if coroutineParamName == "" || !funcDeclCallsCoroutineBegin(funcDecl, coroutineParamName) { + if coroutineParamName == "" || !funcHasSelInBody(funcDecl, coroutineParamName, "Begin") { return false } @@ -599,7 +603,7 @@ func createStmtFromSelFuncCall(lhs, rhs string) ast.Stmt { } } -func funcDeclCallsCoroutineBegin(fd *ast.FuncDecl, coroutineParamName string) bool { +func funcHasSelInBody(fd *ast.FuncDecl, selLhs, selRhs string) bool { if fd.Body == nil || len(fd.Body.List) == 0 { return false @@ -618,12 +622,12 @@ func funcDeclCallsCoroutineBegin(fd *ast.FuncDecl, coroutineParamName string) bo continue } - pkgFuncCallExpr, ok := callExpr.Fun.(*ast.SelectorExpr) + selExpr, ok := callExpr.Fun.(*ast.SelectorExpr) if !ok { continue } - if funcCallHasLhsName(pkgFuncCallExpr, coroutineParamName) { + if selExprIs(selExpr, selLhs, selRhs) { return true } } @@ -633,6 +637,11 @@ func funcDeclCallsCoroutineBegin(fd *ast.FuncDecl, coroutineParamName string) bo func getCoroutineParamNameFromFuncDecl(fd *ast.FuncDecl) string { + // If func doesn't take one parameter then it is not a coroutine + if len(fd.Type.Params.List) == 0 { + return "" + } + for _, p := range fd.Type.Params.List { ptrExpr, ok := p.Type.(*ast.StarExpr) @@ -661,11 +670,7 @@ func getCoroutineParamNameFromFuncDecl(fd *ast.FuncDecl) string { return "" } -// func getIdentNameFromExprOrPanic(e ast.Expr) string { -// return e.(*ast.Ident).Name -// } - -func funcCallHasLhsName(selExpr *ast.SelectorExpr, pkgName string) bool { +func selExprHasLhsName(selExpr *ast.SelectorExpr, pkgName string) bool { pkgIdent, ok := selExpr.X.(*ast.Ident) return ok && pkgIdent.Name == pkgName } @@ -680,21 +685,6 @@ func selExprIs(selExpr *ast.SelectorExpr, lhs, rhs string) bool { return pkgIdentExpr.Name == lhs && selExpr.Sel.Name == rhs } -// func filter[T any](arr []T, where func(x T) bool) []T { - -// out := []T{} -// for i := 0; i < len(arr); i++ { - -// if !where(arr[i]) { -// continue -// } - -// out = append(out, arr[i]) -// } - -// return out -// } - func getCaseWithStmts(caseConditions []ast.Expr, stmts []ast.Stmt) *ast.CaseClause { return &ast.CaseClause{ List: caseConditions,