From da80e90ac8b12e89be440bc45177230706c43920 Mon Sep 17 00:00:00 2001 From: 6543 <6543@obermui.de> Date: Sat, 6 Mar 2021 05:07:03 +0100 Subject: [PATCH] Fix race in local storage (#14888) (#14901) LocalStorage should only put completed files in position Signed-off-by: Andrew Thornton Co-authored-by: zeripath Co-authored-by: techknowlogick --- modules/storage/local.go | 50 +++++++++++++++++++++++++++++++--------- 1 file changed, 39 insertions(+), 11 deletions(-) diff --git a/modules/storage/local.go b/modules/storage/local.go index f7ffb2ddc15..84bf0c6627b 100644 --- a/modules/storage/local.go +++ b/modules/storage/local.go @@ -7,6 +7,7 @@ package storage import ( "context" "io" + "io/ioutil" "net/url" "os" "path/filepath" @@ -24,13 +25,15 @@ const LocalStorageType Type = "local" // LocalStorageConfig represents the configuration for a local storage type LocalStorageConfig struct { - Path string `ini:"PATH"` + Path string `ini:"PATH"` + TemporaryPath string `ini:"TEMPORARY_PATH"` } // LocalStorage represents a local files storage type LocalStorage struct { - ctx context.Context - dir string + ctx context.Context + dir string + tmpdir string } // NewLocalStorage returns a local files @@ -45,9 +48,14 @@ func NewLocalStorage(ctx context.Context, cfg interface{}) (ObjectStorage, error return nil, err } + if config.TemporaryPath == "" { + config.TemporaryPath = config.Path + "/tmp" + } + return &LocalStorage{ - ctx: ctx, - dir: config.Path, + ctx: ctx, + dir: config.Path, + tmpdir: config.TemporaryPath, }, nil } @@ -63,17 +71,37 @@ func (l *LocalStorage) Save(path string, r io.Reader) (int64, error) { return 0, err } - // always override - if err := util.Remove(p); err != nil { + // Create a temporary file to save to + if err := os.MkdirAll(l.tmpdir, os.ModePerm); err != nil { return 0, err } - - f, err := os.Create(p) + tmp, err := ioutil.TempFile(l.tmpdir, "upload-*") if err != nil { return 0, err } - defer f.Close() - return io.Copy(f, r) + tmpRemoved := false + defer func() { + if !tmpRemoved { + _ = util.Remove(tmp.Name()) + } + }() + + n, err := io.Copy(tmp, r) + if err != nil { + return 0, err + } + + if err := tmp.Close(); err != nil { + return 0, err + } + + if err := os.Rename(tmp.Name(), p); err != nil { + return 0, err + } + + tmpRemoved = true + + return n, nil } // Stat returns the info of the file