git » summer » commit 278a534

Extract walking logic to a common function

author Alberto Bertogli
2023-08-28 19:20:35 UTC
committer Alberto Bertogli
2023-08-28 19:20:35 UTC
parent 0099b9c24c6f828e4108256413e2a14e26de5328

Extract walking logic to a common function

All the main operations (update, generate, verify) have the same
structure: a per-file function, and then surrounding logic to walk the
roots.

Today, the latter is duplicated in each function.

This patch abstracts it away, so each per-file function is more
self-contained, and it becomes easier to make changes to the walking logic
(which will happen in later patches).

summer.go +76 -204
walk.go +94 -0

diff --git a/summer.go b/summer.go
index 479e77f..620258e 100644
--- a/summer.go
+++ b/summer.go
@@ -9,7 +9,6 @@ import (
 	"os"
 	"path/filepath"
 	"regexp"
-	"syscall"
 
 	"golang.org/x/term"
 )
@@ -111,11 +110,11 @@ func main() {
 
 	switch op {
 	case "generate":
-		err = generate(roots)
+		err = walk(roots, generate)
 	case "verify":
-		err = verify(roots)
+		err = walk(roots, verify)
 	case "update":
-		err = update(roots)
+		err = walk(roots, update)
 	case "version":
 		PrintVersion()
 	default:
@@ -162,238 +161,111 @@ type ChecksumV1 struct {
 	ModTimeUsec int64
 }
 
-func openAndInfo(path string, d fs.DirEntry, err error, rootDev deviceID) (bool, *os.File, fs.FileInfo, error) {
-	// Excluded check must come first, because it can be use to skip
-	// directories that would otherwise cause errors.
-	if isExcluded(path) {
-		if d.IsDir() {
-			return false, nil, nil, fs.SkipDir
-		}
-		return false, nil, nil, nil
-	}
-
+func generate(fd *os.File, info fs.FileInfo, p *Progress) error {
+	hasAttr, err := options.db.Has(fd)
 	if err != nil {
-		return false, nil, nil, err
+		return err
 	}
-	if d.IsDir() || !d.Type().IsRegular() {
-		return false, nil, nil, nil
+	if hasAttr {
+		// Skip files that already have a checksum.
+		return nil
 	}
 
-	// It is important that we obtain fs.FileInfo at this point, before
-	// reading any of the file contents, because the file could be modified
-	// while we do so. See the comment on ChecksumV1.ModTimeUsec for more
-	// details.
-	info, err := d.Info()
+	h := crc32.New(crc32c)
+	_, err = io.Copy(h, fd)
 	if err != nil {
-		return true, nil, nil, err
+		return err
 	}
 
-	fd, err := os.Open(path)
-	if err != nil {
-		return true, nil, nil, err
+	csum := ChecksumV1{
+		CRC32C:      h.Sum32(),
+		ModTimeUsec: info.ModTime().UnixMicro(),
 	}
 
-	if options.oneFilesystem && rootDev != getDevice(info) {
-		fd.Close()
-		return false, nil, nil, fs.SkipDir
+	err = options.db.Write(fd, csum)
+	if err != nil {
+		return err
 	}
 
-	return true, fd, info, nil
-}
-
-type deviceID uint64
-
-func getDevice(info fs.FileInfo) deviceID {
-	return deviceID(info.Sys().(*syscall.Stat_t).Dev)
+	p.PrintNew(fd.Name(), csum)
+	return nil
 }
 
-func getDeviceForPath(path string) deviceID {
-	fi, err := os.Stat(path)
+func verify(fd *os.File, info fs.FileInfo, p *Progress) error {
+	hasAttr, err := options.db.Has(fd)
 	if err != nil {
-		// Doesn't matter, because we'll get an error during WalkDir.
-		return 0
+		return err
 	}
-	return getDevice(fi)
-}
-
-func generate(roots []string) error {
-	rootDev := deviceID(0)
-	p := NewProgress(options.isTTY)
-	defer p.Stop()
-
-	fn := func(path string, d fs.DirEntry, err error) error {
-		ok, fd, info, err := openAndInfo(path, d, err, rootDev)
-		if !ok || err != nil {
-			return err
-		}
-		defer fd.Close()
-
-		hasAttr, err := options.db.Has(fd)
-		if err != nil {
-			return err
-		}
-		if hasAttr {
-			// Skip files that already have a checksum.
-			return nil
-		}
-
-		h := crc32.New(crc32c)
-		_, err = io.Copy(h, fd)
-		if err != nil {
-			return err
-		}
-
-		csum := ChecksumV1{
-			CRC32C:      h.Sum32(),
-			ModTimeUsec: info.ModTime().UnixMicro(),
-		}
-
-		err = options.db.Write(fd, csum)
-		if err != nil {
-			return err
-		}
-
-		p.PrintNew(path, csum)
+	if !hasAttr {
+		p.PrintMissing(fd.Name(), nil)
 		return nil
 	}
 
-	for _, root := range roots {
-		rootDev = getDeviceForPath(root)
-		err := filepath.WalkDir(root, fn)
-		if err != nil {
-			return err
-		}
+	csumFromFile, err := options.db.Read(fd)
+	if err != nil {
+		return err
 	}
-	return nil
-}
-
-func verify(roots []string) error {
-	rootDev := deviceID(0)
-	p := NewProgress(options.isTTY)
-	defer p.Stop()
-
-	fn := func(path string, d fs.DirEntry, err error) error {
-		ok, fd, info, err := openAndInfo(path, d, err, rootDev)
-		if !ok || err != nil {
-			return err
-		}
-		defer fd.Close()
-
-		hasAttr, err := options.db.Has(fd)
-		if err != nil {
-			return err
-		}
-		if !hasAttr {
-			p.PrintMissing(path, nil)
-			return nil
-		}
-
-		csumFromFile, err := options.db.Read(fd)
-		if err != nil {
-			return err
-		}
 
-		h := crc32.New(crc32c)
-		_, err = io.Copy(h, fd)
-		if err != nil {
-			return err
-		}
-
-		csumComputed := ChecksumV1{
-			CRC32C:      h.Sum32(),
-			ModTimeUsec: info.ModTime().UnixMicro(),
-		}
-
-		if csumFromFile.ModTimeUsec != csumComputed.ModTimeUsec {
-			p.PrintModified(path, csumFromFile, csumComputed)
-		} else if csumFromFile.CRC32C != csumComputed.CRC32C {
-			p.PrintCorrupted(path, csumFromFile, csumComputed)
-		} else {
-			p.PrintMatched(path, csumComputed)
-		}
-
-		return nil
+	h := crc32.New(crc32c)
+	_, err = io.Copy(h, fd)
+	if err != nil {
+		return err
 	}
 
-	var err error
-	for _, root := range roots {
-		rootDev = getDeviceForPath(root)
-		err = filepath.WalkDir(root, fn)
-		if err != nil {
-			break
-		}
+	csumComputed := ChecksumV1{
+		CRC32C:      h.Sum32(),
+		ModTimeUsec: info.ModTime().UnixMicro(),
 	}
 
-	if p.corrupted > 0 && err == nil {
-		err = fmt.Errorf("detected %d corrupted files", p.corrupted)
+	if csumFromFile.ModTimeUsec != csumComputed.ModTimeUsec {
+		p.PrintModified(fd.Name(), csumFromFile, csumComputed)
+	} else if csumFromFile.CRC32C != csumComputed.CRC32C {
+		p.PrintCorrupted(fd.Name(), csumFromFile, csumComputed)
+	} else {
+		p.PrintMatched(fd.Name(), csumComputed)
 	}
-	return err
-}
-
-func update(roots []string) error {
-	rootDev := deviceID(0)
-	p := NewProgress(options.isTTY)
-	defer p.Stop()
-
-	fn := func(path string, d fs.DirEntry, err error) error {
-		ok, fd, info, err := openAndInfo(path, d, err, rootDev)
-		if !ok || err != nil {
-			return err
-		}
-		defer fd.Close()
-
-		// Compute checksum from the current state.
-		h := crc32.New(crc32c)
-		_, err = io.Copy(h, fd)
-		if err != nil {
-			return err
-		}
-
-		csumComputed := ChecksumV1{
-			CRC32C:      h.Sum32(),
-			ModTimeUsec: info.ModTime().UnixMicro(),
-		}
 
-		// Read the saved checksum (if any).
-		hasAttr, err := options.db.Has(fd)
-		if err != nil {
-			return err
-		}
-		if !hasAttr {
-			// Attribute is missing. Expected for newly created files.
-			p.PrintMissing(path, &csumComputed)
-			return options.db.Write(fd, csumComputed)
-		}
+	return nil
+}
 
-		csumFromFile, err := options.db.Read(fd)
-		if err != nil {
-			return err
-		}
+func update(fd *os.File, info fs.FileInfo, p *Progress) error {
+	// Compute checksum from the current state.
+	h := crc32.New(crc32c)
+	_, err := io.Copy(h, fd)
+	if err != nil {
+		return err
+	}
 
-		if csumFromFile.ModTimeUsec != csumComputed.ModTimeUsec {
-			// File modified. Expected for updated files.
-			p.PrintModified(path, csumFromFile, csumComputed)
-			return options.db.Write(fd, csumComputed)
-		} else if csumFromFile.CRC32C != csumComputed.CRC32C {
-			p.PrintCorrupted(path, csumFromFile, csumComputed)
-		} else {
-			p.PrintMatched(path, csumComputed)
-		}
+	csumComputed := ChecksumV1{
+		CRC32C:      h.Sum32(),
+		ModTimeUsec: info.ModTime().UnixMicro(),
+	}
 
-		return nil
+	// Read the saved checksum (if any).
+	hasAttr, err := options.db.Has(fd)
+	if err != nil {
+		return err
+	}
+	if !hasAttr {
+		// Attribute is missing. Expected for newly created files.
+		p.PrintMissing(fd.Name(), &csumComputed)
+		return options.db.Write(fd, csumComputed)
 	}
 
-	var err error
-	for _, root := range roots {
-		rootDev = getDeviceForPath(root)
-		err = filepath.WalkDir(root, fn)
-		if err != nil {
-			break
-		}
+	csumFromFile, err := options.db.Read(fd)
+	if err != nil {
+		return err
 	}
 
-	if p.corrupted > 0 && err == nil {
-		err = fmt.Errorf("detected %d corrupted files", p.corrupted)
+	if csumFromFile.ModTimeUsec != csumComputed.ModTimeUsec {
+		// File modified. Expected for updated files.
+		p.PrintModified(fd.Name(), csumFromFile, csumComputed)
+		return options.db.Write(fd, csumComputed)
+	} else if csumFromFile.CRC32C != csumComputed.CRC32C {
+		p.PrintCorrupted(fd.Name(), csumFromFile, csumComputed)
+	} else {
+		p.PrintMatched(fd.Name(), csumComputed)
 	}
-	return err
+
+	return nil
 }
diff --git a/walk.go b/walk.go
new file mode 100644
index 0000000..403a390
--- /dev/null
+++ b/walk.go
@@ -0,0 +1,94 @@
+package main
+
+import (
+	"fmt"
+	"io/fs"
+	"os"
+	"path/filepath"
+	"syscall"
+)
+
+func openAndInfo(path string, d fs.DirEntry, err error, rootDev deviceID) (bool, *os.File, fs.FileInfo, error) {
+	// Excluded check must come first, because it can be used to skip
+	// directories that would otherwise cause errors.
+	if isExcluded(path) {
+		if d.IsDir() {
+			return false, nil, nil, fs.SkipDir
+		}
+		return false, nil, nil, nil
+	}
+
+	if err != nil {
+		return false, nil, nil, err
+	}
+	if d.IsDir() || !d.Type().IsRegular() {
+		return false, nil, nil, nil
+	}
+
+	// It is important that we obtain fs.FileInfo at this point, before
+	// reading any of the file contents, because the file could be modified
+	// while we do so. See the comment on ChecksumV1.ModTimeUsec for more
+	// details.
+	info, err := d.Info()
+	if err != nil {
+		return true, nil, nil, err
+	}
+
+	fd, err := os.Open(path)
+	if err != nil {
+		return true, nil, nil, err
+	}
+
+	if options.oneFilesystem && rootDev != getDevice(info) {
+		fd.Close()
+		return false, nil, nil, fs.SkipDir
+	}
+
+	return true, fd, info, nil
+}
+
+type deviceID uint64
+
+func getDevice(info fs.FileInfo) deviceID {
+	return deviceID(info.Sys().(*syscall.Stat_t).Dev)
+}
+
+func getDeviceForPath(path string) deviceID {
+	fi, err := os.Stat(path)
+	if err != nil {
+		// Doesn't matter, because we'll get an error during WalkDir.
+		return 0
+	}
+	return getDevice(fi)
+}
+
+type walkFn func(fd *os.File, info fs.FileInfo, p *Progress) error
+
+func walk(roots []string, fn walkFn) error {
+	rootDev := deviceID(0)
+	p := NewProgress(options.isTTY)
+	defer p.Stop()
+
+	wfn := func(path string, d fs.DirEntry, err error) error {
+		ok, fd, info, err := openAndInfo(path, d, err, rootDev)
+		if !ok || err != nil {
+			return err
+		}
+		defer fd.Close()
+		return fn(fd, info, p)
+	}
+
+	var err error
+	for _, root := range roots {
+		rootDev = getDeviceForPath(root)
+		err = filepath.WalkDir(root, wfn)
+		if err != nil {
+			break
+		}
+	}
+
+	if p.corrupted > 0 && err == nil {
+		err = fmt.Errorf("detected %d corrupted files", p.corrupted)
+	}
+	return err
+}