1010use clap:: Parser ;
1111use gettextrs:: { bind_textdomain_codeset, setlocale, textdomain, LocaleCategory } ;
1212use plib:: PROJECT_NAME ;
13- use std:: io:: { self , Read , Write } ;
13+ use std:: error:: Error ;
14+ use std:: io:: { self , Read , StdoutLock , Write } ;
1415use std:: path:: PathBuf ;
1516
16- /// head - copy the first part of files
17+ const N_C_GROUP : & str = "N_C_GROUP" ;
18+
19+ /// head - copy the first part of files.
20+ /// If neither -n nor -c are specified, copies the first 10 lines of each file (-n 10).
1721#[ derive( Parser ) ]
1822#[ command( version, about) ]
1923struct Args {
20- /// The first <N> lines of each input file shall be copied to standard output.
21- #[ arg( short, default_value_t = 10 , value_parser = clap:: value_parser!( u64 ) . range( 1 ..) ) ]
22- n : u64 ,
24+ /// The first <N> lines of each input file shall be copied to standard output (mutually exclusive with -c)
25+ #[ arg( long = "lines" , short, value_parser = clap:: value_parser!( usize ) , group = N_C_GROUP ) ]
26+ n : Option < usize > ,
27+
28+ // Note: -c was added to POSIX in POSIX.1-2024, but has been supported on most platforms since the late 1990s
29+ // https://pubs.opengroup.org/onlinepubs/9799919799/utilities/head.html
30+ //
31+ /// The first <N> bytes of each input file shall be copied to standard output (mutually exclusive with -n)
32+ #[ arg( long = "bytes" , short = 'c' , value_parser = clap:: value_parser!( usize ) , group = N_C_GROUP ) ]
33+ bytes_to_copy : Option < usize > ,
2334
2435 /// Files to read as input.
2536 files : Vec < PathBuf > ,
2637}
2738
28- fn head_file ( args : & Args , pathname : & PathBuf , first : bool , want_header : bool ) -> io:: Result < ( ) > {
39+ enum CountType {
40+ Bytes ( usize ) ,
41+ Lines ( usize ) ,
42+ }
43+
44+ fn head_file (
45+ count_type : & CountType ,
46+ pathname : & PathBuf ,
47+ first : bool ,
48+ want_header : bool ,
49+ stdout_lock : & mut StdoutLock ,
50+ ) -> Result < ( ) , Box < dyn Error > > {
51+ const BUFFER_SIZE : usize = plib:: BUFSZ ;
52+
2953 // print file header
3054 if want_header {
3155 if first {
32- println ! ( "==> {} <==\n " , pathname. display( ) ) ;
56+ writeln ! ( stdout_lock , "==> {} <==" , pathname. display( ) ) ? ;
3357 } else {
34- println ! ( "\n ==> {} <==\n " , pathname. display( ) ) ;
58+ writeln ! ( stdout_lock , "\n ==> {} <==" , pathname. display( ) ) ? ;
3559 }
3660 }
3761
3862 // open file, or stdin
3963 let mut file = plib:: io:: input_stream ( pathname, false ) ?;
4064
41- let mut raw_buffer = [ 0 ; plib:: BUFSZ ] ;
42- let mut nl = 0 ;
65+ let mut raw_buffer = [ 0_u8 ; BUFFER_SIZE ] ;
4366
44- loop {
45- // read a chunk of file data
46- let n_read = file. read ( & mut raw_buffer[ ..] ) ?;
47- if n_read == 0 {
48- break ;
49- }
67+ match * count_type {
68+ CountType :: Bytes ( bytes_to_copy) => {
69+ let mut bytes_remaining = bytes_to_copy;
5070
51- // slice of buffer containing file data
52- let buf = & raw_buffer[ 0 ..n_read] ;
53- let mut pos = 0 ;
71+ loop {
72+ let number_of_bytes_read = {
73+ // Do not read more bytes than necessary
74+ let read_up_to_n_bytes = BUFFER_SIZE . min ( bytes_remaining) ;
5475
55- // count newlines
56- for chv in buf {
57- // LF character encountered
58- if * chv == 10 {
59- nl += 1 ;
60- }
76+ file. read ( & mut raw_buffer[ ..read_up_to_n_bytes] ) ?
77+ } ;
78+
79+ if number_of_bytes_read == 0_usize {
80+ // Reached EOF
81+ break ;
82+ }
83+
84+ let bytes_to_write = & raw_buffer[ ..number_of_bytes_read] ;
6185
62- pos += 1 ;
86+ stdout_lock . write_all ( bytes_to_write ) ? ;
6387
64- // if user-specified limit reached, stop
65- if nl >= args. n {
66- break ;
88+ bytes_remaining -= number_of_bytes_read;
89+
90+ if bytes_remaining == 0_usize {
91+ break ;
92+ }
6793 }
6894 }
95+ CountType :: Lines ( n) => {
96+ let mut nl = 0_usize ;
97+
98+ loop {
99+ // read a chunk of file data
100+ let n_read = file. read ( & mut raw_buffer) ?;
101+
102+ if n_read == 0_usize {
103+ // Reached EOF
104+ break ;
105+ }
106+
107+ // slice of buffer containing file data
108+ let buf = & raw_buffer[ ..n_read] ;
109+
110+ let mut position = 0_usize ;
111+
112+ // count newlines
113+ for & byte in buf {
114+ position += 1 ;
115+
116+ // LF character encountered
117+ if byte == b'\n' {
118+ nl += 1 ;
119+
120+ // if user-specified limit reached, stop
121+ if nl >= n {
122+ break ;
123+ }
124+ }
125+ }
69126
70- // output full or partial buffer
71- let final_buf = & raw_buffer[ 0 ..pos] ;
72- io:: stdout ( ) . write_all ( final_buf) ?;
127+ // output full or partial buffer
128+ let bytes_to_write = & raw_buffer[ ..position] ;
73129
74- // if user-specified limit reached, stop
75- if nl >= args. n {
76- break ;
130+ stdout_lock. write_all ( bytes_to_write) ?;
131+
132+ // if user-specified limit reached, stop
133+ if nl >= n {
134+ break ;
135+ }
136+ }
77137 }
78138 }
79139
@@ -84,21 +144,60 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
84144 // parse command line arguments
85145 let mut args = Args :: parse ( ) ;
86146
147+ // bsdutils (FreeBSD) enforces n > 0 (and c > 0)
148+ // BusyBox, coreutils' uutils, GNU Core Utilities, and toybox do not (and just print nothing)
149+ // POSIX says:
150+ // "The application shall ensure that the number option-argument is a positive decimal integer."
151+ let count_type = match ( args. n , args. bytes_to_copy ) {
152+ ( None , None ) => {
153+ // If no arguments are provided, the default is 10 lines
154+ CountType :: Lines ( 10_usize )
155+ }
156+ ( Some ( n) , None ) => {
157+ if n == 0_usize {
158+ eprintln ! ( "head: when a value for -n is provided, it must be greater than 0" ) ;
159+
160+ std:: process:: exit ( 1_i32 ) ;
161+ }
162+
163+ CountType :: Lines ( n)
164+ }
165+ ( None , Some ( bytes_to_copy) ) => {
166+ if bytes_to_copy == 0_usize {
167+ eprintln ! ( "head: when a value for -c is provided, it must be greater than 0" ) ;
168+
169+ std:: process:: exit ( 1_i32 ) ;
170+ }
171+
172+ CountType :: Bytes ( bytes_to_copy)
173+ }
174+
175+ ( Some ( _) , Some ( _) ) => {
176+ // Will be caught by clap
177+ unreachable ! ( ) ;
178+ }
179+ } ;
180+
87181 setlocale ( LocaleCategory :: LcAll , "" ) ;
88182 textdomain ( PROJECT_NAME ) ?;
89183 bind_textdomain_codeset ( PROJECT_NAME , "UTF-8" ) ?;
90184
185+ let files = & mut args. files ;
186+
91187 // if no files, read from stdin
92- if args . files . is_empty ( ) {
93- args . files . push ( PathBuf :: new ( ) ) ;
188+ if files. is_empty ( ) {
189+ files. push ( PathBuf :: new ( ) ) ;
94190 }
95191
192+ let want_header = files. len ( ) > 1 ;
193+
96194 let mut exit_code = 0 ;
97- let want_header = args. files . len ( ) > 1 ;
98195 let mut first = true ;
99196
100- for filename in & args. files {
101- if let Err ( e) = head_file ( & args, filename, first, want_header) {
197+ let mut stdout_lock = io:: stdout ( ) . lock ( ) ;
198+
199+ for filename in files {
200+ if let Err ( e) = head_file ( & count_type, filename, first, want_header, & mut stdout_lock) {
102201 exit_code = 1 ;
103202 eprintln ! ( "{}: {}" , filename. display( ) , e) ;
104203 }
0 commit comments