// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

// Package iotest implements Readers and Writers useful mainly for testing.
package iotest import ( ) // OneByteReader returns a Reader that implements // each non-empty Read by reading one byte from r. func ( io.Reader) io.Reader { return &oneByteReader{} } type oneByteReader struct { r io.Reader } func ( *oneByteReader) ( []byte) (int, error) { if len() == 0 { return 0, nil } return .r.Read([0:1]) } // HalfReader returns a Reader that implements Read // by reading half as many requested bytes from r. func ( io.Reader) io.Reader { return &halfReader{} } type halfReader struct { r io.Reader } func ( *halfReader) ( []byte) (int, error) { return .r.Read([0 : (len()+1)/2]) } // DataErrReader changes the way errors are handled by a Reader. Normally, a // Reader returns an error (typically EOF) from the first Read call after the // last piece of data is read. DataErrReader wraps a Reader and changes its // behavior so the final error is returned along with the final data, instead // of in the first call after the final data. func ( io.Reader) io.Reader { return &dataErrReader{, nil, make([]byte, 1024)} } type dataErrReader struct { r io.Reader unread []byte data []byte } func ( *dataErrReader) ( []byte) ( int, error) { // loop because first call needs two reads: // one to get data and a second to look for an error. for { if len(.unread) == 0 { , := .r.Read(.data) .unread = .data[0:] = } if > 0 || != nil { break } = copy(, .unread) .unread = .unread[:] } return } // ErrTimeout is a fake timeout error. var ErrTimeout = errors.New("timeout") // TimeoutReader returns [ErrTimeout] on the second read // with no data. Subsequent calls to read succeed. func ( io.Reader) io.Reader { return &timeoutReader{, 0} } type timeoutReader struct { r io.Reader count int } func ( *timeoutReader) ( []byte) (int, error) { .count++ if .count == 2 { return 0, ErrTimeout } return .r.Read() } // ErrReader returns an [io.Reader] that returns 0, err from all Read calls. func ( error) io.Reader { return &errReader{err: } } type errReader struct { err error } func ( *errReader) ( []byte) (int, error) { return 0, .err } type smallByteReader struct { r io.Reader off int n int } func ( *smallByteReader) ( []byte) (int, error) { if len() == 0 { return 0, nil } .n = .n%3 + 1 := .n if > len() { = len() } , := .r.Read([0:]) if != nil && != io.EOF { = fmt.Errorf("Read(%d bytes at offset %d): %v", , .off, ) } .off += return , } // TestReader tests that reading from r returns the expected file content. // It does reads of different sizes, until EOF. // If r implements [io.ReaderAt] or [io.Seeker], TestReader also checks // that those operations behave as they should. // // If TestReader finds any misbehaviors, it returns an error reporting them. // The error text may span multiple lines. func ( io.Reader, []byte) error { if len() > 0 { , := .Read(nil) if != 0 || != nil { return fmt.Errorf("Read(0) = %d, %v, want 0, nil", , ) } } , := io.ReadAll(&smallByteReader{r: }) if != nil { return } if !bytes.Equal(, ) { return fmt.Errorf("ReadAll(small amounts) = %q\n\twant %q", , ) } , := .Read(make([]byte, 10)) if != 0 || != io.EOF { return fmt.Errorf("Read(10) at EOF = %v, %v, want 0, EOF", , ) } if , := .(io.ReadSeeker); { // Seek(0, 1) should report the current file position (EOF). if , := .Seek(0, 1); != int64(len()) || != nil { return fmt.Errorf("Seek(0, 1) from EOF = %d, %v, want %d, nil", , , len()) } // Seek backward partway through file, in two steps. // If middle == 0, len(content) == 0, can't use the -1 and +1 seeks. := len() - len()/3 if > 0 { if , := .Seek(-1, 1); != int64(len()-1) || != nil { return fmt.Errorf("Seek(-1, 1) from EOF = %d, %v, want %d, nil", -, , len()-1) } if , := .Seek(int64(-len()/3), 1); != int64(-1) || != nil { return fmt.Errorf("Seek(%d, 1) from %d = %d, %v, want %d, nil", -len()/3, len()-1, , , -1) } if , := .Seek(+1, 1); != int64() || != nil { return fmt.Errorf("Seek(+1, 1) from %d = %d, %v, want %d, nil", -1, , , ) } } // Seek(0, 1) should report the current file position (middle). if , := .Seek(0, 1); != int64() || != nil { return fmt.Errorf("Seek(0, 1) from %d = %d, %v, want %d, nil", , , , ) } // Reading forward should return the last part of the file. , := io.ReadAll(&smallByteReader{r: }) if != nil { return fmt.Errorf("ReadAll from offset %d: %v", , ) } if !bytes.Equal(, [:]) { return fmt.Errorf("ReadAll from offset %d = %q\n\twant %q", , , [:]) } // Seek relative to end of file, but start elsewhere. if , := .Seek(int64(/2), 0); != int64(/2) || != nil { return fmt.Errorf("Seek(%d, 0) from EOF = %d, %v, want %d, nil", /2, , , /2) } if , := .Seek(int64(-len()/3), 2); != int64() || != nil { return fmt.Errorf("Seek(%d, 2) from %d = %d, %v, want %d, nil", -len()/3, /2, , , ) } // Reading forward should return the last part of the file (again). , = io.ReadAll(&smallByteReader{r: }) if != nil { return fmt.Errorf("ReadAll from offset %d: %v", , ) } if !bytes.Equal(, [:]) { return fmt.Errorf("ReadAll from offset %d = %q\n\twant %q", , , [:]) } // Absolute seek & read forward. if , := .Seek(int64(/2), 0); != int64(/2) || != nil { return fmt.Errorf("Seek(%d, 0) from EOF = %d, %v, want %d, nil", /2, , , /2) } , = io.ReadAll() if != nil { return fmt.Errorf("ReadAll from offset %d: %v", /2, ) } if !bytes.Equal(, [/2:]) { return fmt.Errorf("ReadAll from offset %d = %q\n\twant %q", /2, , [/2:]) } } if , := .(io.ReaderAt); { := make([]byte, len(), len()+1) for := range { [] = 0xfe } , := .ReadAt(, 0) if != len() || != nil && != io.EOF { return fmt.Errorf("ReadAt(%d, 0) = %v, %v, want %d, nil or EOF", len(), , , len()) } if !bytes.Equal(, ) { return fmt.Errorf("ReadAt(%d, 0) = %q\n\twant %q", len(), , ) } , = .ReadAt([:1], int64(len())) if != 0 || != io.EOF { return fmt.Errorf("ReadAt(1, %d) = %v, %v, want 0, EOF", len(), , ) } for := range { [] = 0xfe } , = .ReadAt([:cap()], 0) if != len() || != io.EOF { return fmt.Errorf("ReadAt(%d, 0) = %v, %v, want %d, EOF", cap(), , , len()) } if !bytes.Equal(, ) { return fmt.Errorf("ReadAt(%d, 0) = %q\n\twant %q", len(), , ) } for := range { [] = 0xfe } for := range { , = .ReadAt([:+1], int64()) if != 1 || != nil && ( != len()-1 || != io.EOF) { := "nil" if == len()-1 { = "nil or EOF" } return fmt.Errorf("ReadAt(1, %d) = %v, %v, want 1, %s", , , , ) } if [] != [] { return fmt.Errorf("ReadAt(1, %d) = %q want %q", , [:+1], [:+1]) } } } return nil }