From 714ab71ddc4260937b1480519d453d2dc4e77dd6 Mon Sep 17 00:00:00 2001 From: zeripath Date: Wed, 2 Sep 2020 18:49:25 +0100 Subject: [PATCH] Ensure that all migration requests are cancellable (#12669) * Ensure that all migration requests are cancellable Signed-off-by: Andrew Thornton * Use WithContext as RequestWithContext is go 1.14 Signed-off-by: Andrew Thornton Co-authored-by: techknowlogick --- modules/migrations/base/downloader.go | 61 ++++++++++++++++++++++----- modules/migrations/gitea_test.go | 3 +- modules/migrations/github.go | 8 ++-- modules/migrations/github_test.go | 3 +- modules/migrations/gitlab.go | 41 ++++++++++-------- modules/migrations/gitlab_test.go | 3 +- modules/migrations/migrate.go | 6 +-- 7 files changed, 86 insertions(+), 39 deletions(-) diff --git a/modules/migrations/base/downloader.go b/modules/migrations/base/downloader.go index b692969ba50..036abf22c91 100644 --- a/modules/migrations/base/downloader.go +++ b/modules/migrations/base/downloader.go @@ -35,7 +35,7 @@ type Downloader interface { // DownloaderFactory defines an interface to match a downloader implementation and create a downloader type DownloaderFactory interface { - New(opts MigrateOptions) (Downloader, error) + New(ctx context.Context, opts MigrateOptions) (Downloader, error) GitServiceType() structs.GitServiceType } @@ -46,14 +46,16 @@ var ( // RetryDownloader retry the downloads type RetryDownloader struct { Downloader + ctx context.Context RetryTimes int // the total execute times RetryDelay int // time to delay seconds } // NewRetryDownloader creates a retry downloader -func NewRetryDownloader(downloader Downloader, retryTimes, retryDelay int) *RetryDownloader { +func NewRetryDownloader(ctx context.Context, downloader Downloader, retryTimes, retryDelay int) *RetryDownloader { return &RetryDownloader{ Downloader: downloader, + ctx: ctx, RetryTimes: retryTimes, RetryDelay: retryDelay, } @@ -61,6 +63,7 @@ func NewRetryDownloader(downloader Downloader, retryTimes, retryDelay int) *Retr // SetContext set context func (d *RetryDownloader) SetContext(ctx context.Context) { + d.ctx = ctx d.Downloader.SetContext(ctx) } @@ -75,7 +78,11 @@ func (d *RetryDownloader) GetRepoInfo() (*Repository, error) { if repo, err = d.Downloader.GetRepoInfo(); err == nil { return repo, nil } - time.Sleep(time.Second * time.Duration(d.RetryDelay)) + select { + case <-d.ctx.Done(): + return nil, d.ctx.Err() + case <-time.After(time.Second * time.Duration(d.RetryDelay)): + } } return nil, err } @@ -91,7 +98,11 @@ func (d *RetryDownloader) GetTopics() ([]string, error) { if topics, err = d.Downloader.GetTopics(); err == nil { return topics, nil } - time.Sleep(time.Second * time.Duration(d.RetryDelay)) + select { + case <-d.ctx.Done(): + return nil, d.ctx.Err() + case <-time.After(time.Second * time.Duration(d.RetryDelay)): + } } return nil, err } @@ -107,7 +118,11 @@ func (d *RetryDownloader) GetMilestones() ([]*Milestone, error) { if milestones, err = d.Downloader.GetMilestones(); err == nil { return milestones, nil } - time.Sleep(time.Second * time.Duration(d.RetryDelay)) + select { + case <-d.ctx.Done(): + return nil, d.ctx.Err() + case <-time.After(time.Second * time.Duration(d.RetryDelay)): + } } return nil, err } @@ -123,7 +138,11 @@ func (d *RetryDownloader) GetReleases() ([]*Release, error) { if releases, err = d.Downloader.GetReleases(); err == nil { return releases, nil } - time.Sleep(time.Second * time.Duration(d.RetryDelay)) + select { + case <-d.ctx.Done(): + return nil, d.ctx.Err() + case <-time.After(time.Second * time.Duration(d.RetryDelay)): + } } return nil, err } @@ -139,7 +158,11 @@ func (d *RetryDownloader) GetLabels() ([]*Label, error) { if labels, err = d.Downloader.GetLabels(); err == nil { return labels, nil } - time.Sleep(time.Second * time.Duration(d.RetryDelay)) + select { + case <-d.ctx.Done(): + return nil, d.ctx.Err() + case <-time.After(time.Second * time.Duration(d.RetryDelay)): + } } return nil, err } @@ -156,7 +179,11 @@ func (d *RetryDownloader) GetIssues(page, perPage int) ([]*Issue, bool, error) { if issues, isEnd, err = d.Downloader.GetIssues(page, perPage); err == nil { return issues, isEnd, nil } - time.Sleep(time.Second * time.Duration(d.RetryDelay)) + select { + case <-d.ctx.Done(): + return nil, false, d.ctx.Err() + case <-time.After(time.Second * time.Duration(d.RetryDelay)): + } } return nil, false, err } @@ -172,7 +199,11 @@ func (d *RetryDownloader) GetComments(issueNumber int64) ([]*Comment, error) { if comments, err = d.Downloader.GetComments(issueNumber); err == nil { return comments, nil } - time.Sleep(time.Second * time.Duration(d.RetryDelay)) + select { + case <-d.ctx.Done(): + return nil, d.ctx.Err() + case <-time.After(time.Second * time.Duration(d.RetryDelay)): + } } return nil, err } @@ -188,7 +219,11 @@ func (d *RetryDownloader) GetPullRequests(page, perPage int) ([]*PullRequest, er if prs, err = d.Downloader.GetPullRequests(page, perPage); err == nil { return prs, nil } - time.Sleep(time.Second * time.Duration(d.RetryDelay)) + select { + case <-d.ctx.Done(): + return nil, d.ctx.Err() + case <-time.After(time.Second * time.Duration(d.RetryDelay)): + } } return nil, err } @@ -204,7 +239,11 @@ func (d *RetryDownloader) GetReviews(pullRequestNumber int64) ([]*Review, error) if reviews, err = d.Downloader.GetReviews(pullRequestNumber); err == nil { return reviews, nil } - time.Sleep(time.Second * time.Duration(d.RetryDelay)) + select { + case <-d.ctx.Done(): + return nil, d.ctx.Err() + case <-time.After(time.Second * time.Duration(d.RetryDelay)): + } } return nil, err } diff --git a/modules/migrations/gitea_test.go b/modules/migrations/gitea_test.go index 02b2f0a5c9f..62c8f713226 100644 --- a/modules/migrations/gitea_test.go +++ b/modules/migrations/gitea_test.go @@ -6,6 +6,7 @@ package migrations import ( + "context" "testing" "time" @@ -26,7 +27,7 @@ func TestGiteaUploadRepo(t *testing.T) { user := models.AssertExistsAndLoadBean(t, &models.User{ID: 1}).(*models.User) var ( - downloader = NewGithubDownloaderV3("", "", "", "go-xorm", "builder") + downloader = NewGithubDownloaderV3(context.Background(), "", "", "", "go-xorm", "builder") repoName = "builder-" + time.Now().Format("2006-01-02-15-04-05") uploader = NewGiteaLocalUploader(graceful.GetManager().HammerContext(), user, user.Name, repoName) ) diff --git a/modules/migrations/github.go b/modules/migrations/github.go index eb73a7e0d40..e5cc3b82235 100644 --- a/modules/migrations/github.go +++ b/modules/migrations/github.go @@ -41,7 +41,7 @@ type GithubDownloaderV3Factory struct { } // New returns a Downloader related to this factory according MigrateOptions -func (f *GithubDownloaderV3Factory) New(opts base.MigrateOptions) (base.Downloader, error) { +func (f *GithubDownloaderV3Factory) New(ctx context.Context, opts base.MigrateOptions) (base.Downloader, error) { u, err := url.Parse(opts.CloneAddr) if err != nil { return nil, err @@ -53,7 +53,7 @@ func (f *GithubDownloaderV3Factory) New(opts base.MigrateOptions) (base.Download log.Trace("Create github downloader: %s/%s", oldOwner, oldName) - return NewGithubDownloaderV3(opts.AuthUsername, opts.AuthPassword, opts.AuthToken, oldOwner, oldName), nil + return NewGithubDownloaderV3(ctx, opts.AuthUsername, opts.AuthPassword, opts.AuthToken, oldOwner, oldName), nil } // GitServiceType returns the type of git service @@ -74,11 +74,11 @@ type GithubDownloaderV3 struct { } // NewGithubDownloaderV3 creates a github Downloader via github v3 API -func NewGithubDownloaderV3(userName, password, token, repoOwner, repoName string) *GithubDownloaderV3 { +func NewGithubDownloaderV3(ctx context.Context, userName, password, token, repoOwner, repoName string) *GithubDownloaderV3 { var downloader = GithubDownloaderV3{ userName: userName, password: password, - ctx: context.Background(), + ctx: ctx, repoOwner: repoOwner, repoName: repoName, } diff --git a/modules/migrations/github_test.go b/modules/migrations/github_test.go index 0b8c559d305..660e82d6457 100644 --- a/modules/migrations/github_test.go +++ b/modules/migrations/github_test.go @@ -6,6 +6,7 @@ package migrations import ( + "context" "os" "testing" "time" @@ -64,7 +65,7 @@ func assertLabelEqual(t *testing.T, name, color, description string, label *base func TestGitHubDownloadRepo(t *testing.T) { GithubLimitRateRemaining = 3 //Wait at 3 remaining since we could have 3 CI in // - downloader := NewGithubDownloaderV3("", "", os.Getenv("GITHUB_READ_TOKEN"), "go-gitea", "test_repo") + downloader := NewGithubDownloaderV3(context.Background(), "", "", os.Getenv("GITHUB_READ_TOKEN"), "go-gitea", "test_repo") err := downloader.RefreshRate() assert.NoError(t, err) diff --git a/modules/migrations/gitlab.go b/modules/migrations/gitlab.go index eec16d24333..c03ce89c608 100644 --- a/modules/migrations/gitlab.go +++ b/modules/migrations/gitlab.go @@ -35,7 +35,7 @@ type GitlabDownloaderFactory struct { } // New returns a Downloader related to this factory according MigrateOptions -func (f *GitlabDownloaderFactory) New(opts base.MigrateOptions) (base.Downloader, error) { +func (f *GitlabDownloaderFactory) New(ctx context.Context, opts base.MigrateOptions) (base.Downloader, error) { u, err := url.Parse(opts.CloneAddr) if err != nil { return nil, err @@ -47,7 +47,7 @@ func (f *GitlabDownloaderFactory) New(opts base.MigrateOptions) (base.Downloader log.Trace("Create gitlab downloader. BaseURL: %s RepoName: %s", baseURL, repoNameSpace) - return NewGitlabDownloader(baseURL, repoNameSpace, opts.AuthUsername, opts.AuthPassword, opts.AuthToken), nil + return NewGitlabDownloader(ctx, baseURL, repoNameSpace, opts.AuthUsername, opts.AuthPassword, opts.AuthToken), nil } // GitServiceType returns the type of git service @@ -73,7 +73,7 @@ type GitlabDownloader struct { // NewGitlabDownloader creates a gitlab Downloader via gitlab API // Use either a username/password, personal token entered into the username field, or anonymous/public access // Note: Public access only allows very basic access -func NewGitlabDownloader(baseURL, repoPath, username, password, token string) *GitlabDownloader { +func NewGitlabDownloader(ctx context.Context, baseURL, repoPath, username, password, token string) *GitlabDownloader { var gitlabClient *gitlab.Client var err error if token != "" { @@ -88,7 +88,7 @@ func NewGitlabDownloader(baseURL, repoPath, username, password, token string) *G } // Grab and store project/repo ID here, due to issues using the URL escaped path - gr, _, err := gitlabClient.Projects.GetProject(repoPath, nil, nil) + gr, _, err := gitlabClient.Projects.GetProject(repoPath, nil, nil, gitlab.WithContext(ctx)) if err != nil { log.Trace("Error retrieving project: %v", err) return nil @@ -100,7 +100,7 @@ func NewGitlabDownloader(baseURL, repoPath, username, password, token string) *G } return &GitlabDownloader{ - ctx: context.Background(), + ctx: ctx, client: gitlabClient, repoID: gr.ID, repoName: gr.Name, @@ -118,7 +118,7 @@ func (g *GitlabDownloader) GetRepoInfo() (*base.Repository, error) { return nil, errors.New("error: GitlabDownloader is nil") } - gr, _, err := g.client.Projects.GetProject(g.repoID, nil, nil) + gr, _, err := g.client.Projects.GetProject(g.repoID, nil, nil, gitlab.WithContext(g.ctx)) if err != nil { return nil, err } @@ -158,7 +158,7 @@ func (g *GitlabDownloader) GetTopics() ([]string, error) { return nil, errors.New("error: GitlabDownloader is nil") } - gr, _, err := g.client.Projects.GetProject(g.repoID, nil, nil) + gr, _, err := g.client.Projects.GetProject(g.repoID, nil, nil, gitlab.WithContext(g.ctx)) if err != nil { return nil, err } @@ -179,7 +179,7 @@ func (g *GitlabDownloader) GetMilestones() ([]*base.Milestone, error) { ListOptions: gitlab.ListOptions{ Page: i, PerPage: perPage, - }}, nil) + }}, nil, gitlab.WithContext(g.ctx)) if err != nil { return nil, err } @@ -237,7 +237,7 @@ func (g *GitlabDownloader) GetLabels() ([]*base.Label, error) { ls, _, err := g.client.Labels.ListLabels(g.repoID, &gitlab.ListLabelsOptions{ Page: i, PerPage: perPage, - }, nil) + }, nil, gitlab.WithContext(g.ctx)) if err != nil { return nil, err } @@ -288,7 +288,7 @@ func (g *GitlabDownloader) GetReleases() ([]*base.Release, error) { ls, _, err := g.client.Releases.ListReleases(g.repoID, &gitlab.ListReleasesOptions{ Page: i, PerPage: perPage, - }, nil) + }, nil, gitlab.WithContext(g.ctx)) if err != nil { return nil, err } @@ -305,11 +305,18 @@ func (g *GitlabDownloader) GetReleases() ([]*base.Release, error) { // GetAsset returns an asset func (g *GitlabDownloader) GetAsset(tag string, id int64) (io.ReadCloser, error) { - link, _, err := g.client.ReleaseLinks.GetReleaseLink(g.repoID, tag, int(id)) + link, _, err := g.client.ReleaseLinks.GetReleaseLink(g.repoID, tag, int(id), gitlab.WithContext(g.ctx)) if err != nil { return nil, err } - resp, err := http.Get(link.URL) + + req, err := http.NewRequest("GET", link.URL, nil) + if err != nil { + return nil, err + } + req = req.WithContext(g.ctx) + + resp, err := http.DefaultClient.Do(req) if err != nil { return nil, err } @@ -336,7 +343,7 @@ func (g *GitlabDownloader) GetIssues(page, perPage int) ([]*base.Issue, bool, er var allIssues = make([]*base.Issue, 0, perPage) - issues, _, err := g.client.Issues.ListProjectIssues(g.repoID, opt, nil) + issues, _, err := g.client.Issues.ListProjectIssues(g.repoID, opt, nil, gitlab.WithContext(g.ctx)) if err != nil { return nil, false, fmt.Errorf("error while listing issues: %v", err) } @@ -393,14 +400,14 @@ func (g *GitlabDownloader) GetComments(issueNumber int64) ([]*base.Comment, erro comments, resp, err = g.client.Discussions.ListIssueDiscussions(g.repoID, int(realIssueNumber), &gitlab.ListIssueDiscussionsOptions{ Page: page, PerPage: 100, - }, nil) + }, nil, gitlab.WithContext(g.ctx)) } else { // If this is a PR, we need to figure out the Gitlab/original PR ID to be passed below realIssueNumber = issueNumber - g.issueCount comments, resp, err = g.client.Discussions.ListMergeRequestDiscussions(g.repoID, int(realIssueNumber), &gitlab.ListMergeRequestDiscussionsOptions{ Page: page, PerPage: 100, - }, nil) + }, nil, gitlab.WithContext(g.ctx)) } if err != nil { @@ -455,7 +462,7 @@ func (g *GitlabDownloader) GetPullRequests(page, perPage int) ([]*base.PullReque var allPRs = make([]*base.PullRequest, 0, perPage) - prs, _, err := g.client.MergeRequests.ListProjectMergeRequests(g.repoID, opt, nil) + prs, _, err := g.client.MergeRequests.ListProjectMergeRequests(g.repoID, opt, nil, gitlab.WithContext(g.ctx)) if err != nil { return nil, fmt.Errorf("error while listing merge requests: %v", err) } @@ -536,7 +543,7 @@ func (g *GitlabDownloader) GetPullRequests(page, perPage int) ([]*base.PullReque // GetReviews returns pull requests review func (g *GitlabDownloader) GetReviews(pullRequestNumber int64) ([]*base.Review, error) { - state, _, err := g.client.MergeRequestApprovals.GetApprovalState(g.repoID, int(pullRequestNumber)) + state, _, err := g.client.MergeRequestApprovals.GetApprovalState(g.repoID, int(pullRequestNumber), gitlab.WithContext(g.ctx)) if err != nil { return nil, err } diff --git a/modules/migrations/gitlab_test.go b/modules/migrations/gitlab_test.go index daf05f8e3a6..1862d67cc11 100644 --- a/modules/migrations/gitlab_test.go +++ b/modules/migrations/gitlab_test.go @@ -5,6 +5,7 @@ package migrations import ( + "context" "net/http" "os" "testing" @@ -27,7 +28,7 @@ func TestGitlabDownloadRepo(t *testing.T) { t.Skipf("Can't access test repo, skipping %s", t.Name()) } - downloader := NewGitlabDownloader("https://gitlab.com", "gitea/test_repo", "", "", gitlabPersonalAccessToken) + downloader := NewGitlabDownloader(context.Background(), "https://gitlab.com", "gitea/test_repo", "", "", gitlabPersonalAccessToken) if downloader == nil { t.Fatal("NewGitlabDownloader is nil") } diff --git a/modules/migrations/migrate.go b/modules/migrations/migrate.go index 7858dfc6850..8543a3fc099 100644 --- a/modules/migrations/migrate.go +++ b/modules/migrations/migrate.go @@ -37,7 +37,7 @@ func MigrateRepository(ctx context.Context, doer *models.User, ownerName string, for _, factory := range factories { if factory.GitServiceType() == opts.GitServiceType { - downloader, err = factory.New(opts) + downloader, err = factory.New(ctx, opts) if err != nil { return nil, err } @@ -60,11 +60,9 @@ func MigrateRepository(ctx context.Context, doer *models.User, ownerName string, uploader.gitServiceType = opts.GitServiceType if setting.Migrations.MaxAttempts > 1 { - downloader = base.NewRetryDownloader(downloader, setting.Migrations.MaxAttempts, setting.Migrations.RetryBackoff) + downloader = base.NewRetryDownloader(ctx, downloader, setting.Migrations.MaxAttempts, setting.Migrations.RetryBackoff) } - downloader.SetContext(ctx) - if err := migrateRepository(downloader, uploader, opts); err != nil { if err1 := uploader.Rollback(); err1 != nil { log.Error("rollback failed: %v", err1)