From 0749bb54847e17fb7a2376e392871049f44793e7 Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Thu, 12 Dec 2024 15:31:58 -0800 Subject: [PATCH] Refactor getpatch/getdiff functions and remove fallback automatically --- modules/git/repo_compare.go | 90 ++++++++++++-------------------- modules/git/repo_compare_test.go | 27 +++++++++- services/pull/patch.go | 18 +++++-- 3 files changed, 73 insertions(+), 62 deletions(-) diff --git a/modules/git/repo_compare.go b/modules/git/repo_compare.go index 16fcdcf4c8f..5d7fbe6df64 100644 --- a/modules/git/repo_compare.go +++ b/modules/git/repo_compare.go @@ -233,72 +233,63 @@ func parseDiffStat(stdout string) (numFiles, totalAdditions, totalDeletions int, return numFiles, totalAdditions, totalDeletions, err } -// GetDiffOrPatch generates either diff or formatted patch data between given revisions -func (repo *Repository) GetDiffOrPatch(base, head string, w io.Writer, patch, binary bool) error { - if patch { - return repo.GetPatch(base, head, w) +func parseCompareArgs(compareArgs string) (args []string) { + parts := strings.Split(compareArgs, "...") + if len(parts) == 2 { + return []string{compareArgs} } - if binary { - return repo.GetDiffBinary(base, head, w) + parts = strings.Split(compareArgs, "..") + if len(parts) == 2 { + return parts } - return repo.GetDiff(base, head, w) + parts = strings.Fields(compareArgs) + if len(parts) == 2 { + return parts + } + + return nil } // GetDiff generates and returns patch data between given revisions, optimized for human readability -func (repo *Repository) GetDiff(base, head string, w io.Writer) error { +func (repo *Repository) GetDiff(compareArgs string, w io.Writer) error { + args := parseCompareArgs(compareArgs) + if len(args) == 0 { + return fmt.Errorf("invalid compareArgs: %s", compareArgs) + } stderr := new(bytes.Buffer) - err := NewCommand(repo.Ctx, "diff", "-p").AddDynamicArguments(base + "..." + head). + return NewCommand(repo.Ctx, "diff", "-p").AddDynamicArguments(args...). Run(&RunOpts{ Dir: repo.Path, Stdout: w, Stderr: stderr, }) - if err != nil && bytes.Contains(stderr.Bytes(), []byte("no merge base")) { - return NewCommand(repo.Ctx, "diff", "-p").AddDynamicArguments(base, head). - Run(&RunOpts{ - Dir: repo.Path, - Stdout: w, - }) - } - return err } // GetDiffBinary generates and returns patch data between given revisions, including binary diffs. -func (repo *Repository) GetDiffBinary(base, head string, w io.Writer) error { - stderr := new(bytes.Buffer) - err := NewCommand(repo.Ctx, "diff", "-p", "--binary", "--histogram").AddDynamicArguments(base + "..." + head). - Run(&RunOpts{ - Dir: repo.Path, - Stdout: w, - Stderr: stderr, - }) - if err != nil && bytes.Contains(stderr.Bytes(), []byte("no merge base")) { - return NewCommand(repo.Ctx, "diff", "-p", "--binary", "--histogram").AddDynamicArguments(base, head). - Run(&RunOpts{ - Dir: repo.Path, - Stdout: w, - }) +func (repo *Repository) GetDiffBinary(compareArgs string, w io.Writer) error { + args := parseCompareArgs(compareArgs) + if len(args) == 0 { + return fmt.Errorf("invalid compareArgs: %s", compareArgs) } - return err + return NewCommand(repo.Ctx, "diff", "-p", "--binary", "--histogram").AddDynamicArguments(args...).Run(&RunOpts{ + Dir: repo.Path, + Stdout: w, + }) } // GetPatch generates and returns format-patch data between given revisions, able to be used with `git apply` -func (repo *Repository) GetPatch(base, head string, w io.Writer) error { +func (repo *Repository) GetPatch(compareArgs string, w io.Writer) error { + args := parseCompareArgs(compareArgs) + if len(args) == 0 { + return fmt.Errorf("invalid compareArgs: %s", compareArgs) + } stderr := new(bytes.Buffer) - err := NewCommand(repo.Ctx, "format-patch", "--binary", "--stdout").AddDynamicArguments(base + "..." + head). + return NewCommand(repo.Ctx, "format-patch", "--binary", "--stdout").AddDynamicArguments(args...). Run(&RunOpts{ Dir: repo.Path, Stdout: w, Stderr: stderr, }) - if err != nil && bytes.Contains(stderr.Bytes(), []byte("no merge base")) { - return NewCommand(repo.Ctx, "format-patch", "--binary", "--stdout").AddDynamicArguments(base, head). - Run(&RunOpts{ - Dir: repo.Path, - Stdout: w, - }) - } - return err } // GetFilesChangedBetween returns a list of all files that have been changed between the given commits @@ -329,21 +320,6 @@ func (repo *Repository) GetFilesChangedBetween(base, head string) ([]string, err return split, err } -// GetDiffFromMergeBase generates and return patch data from merge base to head -func (repo *Repository) GetDiffFromMergeBase(base, head string, w io.Writer) error { - stderr := new(bytes.Buffer) - err := NewCommand(repo.Ctx, "diff", "-p", "--binary").AddDynamicArguments(base + "..." + head). - Run(&RunOpts{ - Dir: repo.Path, - Stdout: w, - Stderr: stderr, - }) - if err != nil && bytes.Contains(stderr.Bytes(), []byte("no merge base")) { - return repo.GetDiffBinary(base, head, w) - } - return err -} - // ReadPatchCommit will check if a diff patch exists and return stats func (repo *Repository) ReadPatchCommit(prID int64) (commitSHA string, err error) { // Migrated repositories download patches to "pulls" location diff --git a/modules/git/repo_compare_test.go b/modules/git/repo_compare_test.go index 99838731867..d4597bd9486 100644 --- a/modules/git/repo_compare_test.go +++ b/modules/git/repo_compare_test.go @@ -12,6 +12,31 @@ import ( "github.com/stretchr/testify/assert" ) +func Test_parseCompareArgs(t *testing.T) { + testCases := []struct { + compareString string + expected []string + }{ + { + "master..develop", + []string{"master", "develop"}, + }, + { + "master HEAD", + []string{"master", "HEAD"}, + }, + { + "HEAD...develop", + []string{"HEAD...develop"}, + }, + } + + for _, tc := range testCases { + args := parseCompareArgs(tc.compareString) + assert.Equal(t, tc.expected, args) + } +} + func TestGetFormatPatch(t *testing.T) { bareRepo1Path := filepath.Join(testReposDir, "repo1_bare") clonedPath, err := cloneRepo(t, bareRepo1Path) @@ -28,7 +53,7 @@ func TestGetFormatPatch(t *testing.T) { defer repo.Close() rd := &bytes.Buffer{} - err = repo.GetPatch("8d92fc95^", "8d92fc95", rd) + err = repo.GetPatch("8d92fc95^...8d92fc95", rd) if err != nil { assert.NoError(t, err) return diff --git a/services/pull/patch.go b/services/pull/patch.go index 0934a86c89a..296a84bb429 100644 --- a/services/pull/patch.go +++ b/services/pull/patch.go @@ -42,9 +42,19 @@ func DownloadDiffOrPatch(ctx context.Context, pr *issues_model.PullRequest, w io } defer closer.Close() - if err := gitRepo.GetDiffOrPatch(pr.MergeBase, pr.GetGitRefName(), w, patch, binary); err != nil { - log.Error("Unable to get patch file from %s to %s in %s Error: %v", pr.MergeBase, pr.HeadBranch, pr.BaseRepo.FullName(), err) - return fmt.Errorf("Unable to get patch file from %s to %s in %s Error: %w", pr.MergeBase, pr.HeadBranch, pr.BaseRepo.FullName(), err) + compareString := pr.MergeBase + "..." + pr.GetGitRefName() + switch { + case patch: + err = gitRepo.GetPatch(compareString, w) + case binary: + err = gitRepo.GetDiffBinary(compareString, w) + default: + err = gitRepo.GetDiff(compareString, w) + } + + if err != nil { + log.Error("unable to get patch file from %s to %s in %s Error: %v", pr.MergeBase, pr.HeadBranch, pr.BaseRepo.FullName(), err) + return fmt.Errorf("unable to get patch file from %s to %s in %s Error: %w", pr.MergeBase, pr.HeadBranch, pr.BaseRepo.FullName(), err) } return nil } @@ -355,7 +365,7 @@ func checkConflicts(ctx context.Context, pr *issues_model.PullRequest, gitRepo * _ = util.Remove(tmpPatchFile.Name()) }() - if err := gitRepo.GetDiffBinary(pr.MergeBase, "tracking", tmpPatchFile); err != nil { + if err := gitRepo.GetDiffBinary(pr.MergeBase+"...tracking", tmpPatchFile); err != nil { tmpPatchFile.Close() log.Error("Unable to get patch file from %s to %s in %s Error: %v", pr.MergeBase, pr.HeadBranch, pr.BaseRepo.FullName(), err) return false, fmt.Errorf("unable to get patch file from %s to %s in %s Error: %w", pr.MergeBase, pr.HeadBranch, pr.BaseRepo.FullName(), err)