diff --git a/clamrest.go b/clamrest.go index 43aa1b3..7dd17e7 100644 --- a/clamrest.go +++ b/clamrest.go @@ -3,7 +3,6 @@ package main import ( "encoding/json" "fmt" - "io" "io/ioutil" "log" "net/http" @@ -21,24 +20,33 @@ func init() { log.SetOutput(ioutil.Discard) } +type Error struct { + Error string `json:"Error"` +} + +func writeError(w http.ResponseWriter, statusCode int, err string) { + w.Header().Set("Content-Type", "application/json; charset=utf-8") + w.WriteHeader(statusCode) + + errJson, _ := json.Marshal(&Error{err}) + if errJson != nil { + fmt.Fprint(w, string(errJson)) + } +} + func home(w http.ResponseWriter, r *http.Request) { c := clamd.NewClamd(opts["CLAMD_PORT"]) response, err := c.Stats() if err != nil { - errJson, eErr := json.Marshal(err) - if eErr != nil { - fmt.Println(eErr) - return - } - fmt.Fprint(w, string(errJson)) + writeError(w, http.StatusInternalServerError, "Could not get stats: "+err.Error()) return } resJson, eRes := json.Marshal(response) if eRes != nil { - fmt.Println(eRes) + writeError(w, http.StatusInternalServerError, "Could not marshal JSON: "+eRes.Error()) return } w.Header().Set("Content-Type", "application/json; charset=utf-8") @@ -58,12 +66,7 @@ func scanPathHandler(w http.ResponseWriter, r *http.Request) { response, err := c.AllMatchScanFile(path) if err != nil { - errJson, eErr := json.Marshal(err) - if eErr != nil { - fmt.Println(eErr) - return - } - fmt.Fprint(w, string(errJson)) + writeError(w, http.StatusInternalServerError, "Could not scan file: "+err.Error()) return } @@ -92,45 +95,56 @@ func scanHandler(w http.ResponseWriter, r *http.Request) { reader, err := r.MultipartReader() if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) + writeError(w, http.StatusInternalServerError, "Could not initialize reader: "+err.Error()) return } - //copy each part to destination. - for { - part, err := reader.NextPart() - if err == io.EOF { - break - } - - //if part.FileName() is empty, skip this iteration. - if part.FileName() == "" { - continue - } - - fmt.Printf(time.Now().Format(time.RFC3339) + " Started scanning: " + part.FileName() + "\n") - var abort chan bool - response, err := c.ScanStream(part, abort) - for s := range response { - w.Header().Set("Content-Type", "application/json; charset=utf-8") - respJson := fmt.Sprintf("{ Status: \"%s\", Description: \"%s\" }", s.Status, s.Description) - switch s.Status { - case clamd.RES_OK: - w.WriteHeader(http.StatusOK) - case clamd.RES_FOUND: - w.WriteHeader(http.StatusNotAcceptable) - case clamd.RES_ERROR: - w.WriteHeader(http.StatusBadRequest) - case clamd.RES_PARSE_ERROR: - w.WriteHeader(http.StatusPreconditionFailed) - default: - w.WriteHeader(http.StatusNotImplemented) - } - fmt.Fprint(w, respJson) - fmt.Printf(time.Now().Format(time.RFC3339)+" Scan result for: %v, %v\n", part.FileName(), s) - } - fmt.Printf(time.Now().Format(time.RFC3339) + " Finished scanning: " + part.FileName() + "\n") + part, err := reader.NextPart() + if err != nil { + writeError(w, http.StatusInternalServerError, "Could not read file: "+err.Error()) + return } + + //if part.FileName() is empty, skip this iteration. + if part.FileName() == "" { + writeError(w, http.StatusBadRequest, "Filename is empty") + return + } + + fmt.Printf(time.Now().Format(time.RFC3339) + " Started scanning: " + part.FileName() + "\n") + var abort chan bool + response, err := c.ScanStream(part, abort) + if err != nil { + writeError(w, http.StatusInternalServerError, "Could not scan file: "+err.Error()) + return + } + + for s := range response { + w.Header().Set("Content-Type", "application/json; charset=utf-8") + + respJson, err := json.Marshal(&s) + if err != nil { + writeError(w, http.StatusInternalServerError, "Could not marshal JSON: "+err.Error()) + return + } + + switch s.Status { + case clamd.RES_OK: + w.WriteHeader(http.StatusOK) + case clamd.RES_FOUND: + w.WriteHeader(http.StatusNotAcceptable) + case clamd.RES_ERROR: + w.WriteHeader(http.StatusBadRequest) + case clamd.RES_PARSE_ERROR: + w.WriteHeader(http.StatusPreconditionFailed) + default: + w.WriteHeader(http.StatusNotImplemented) + } + + fmt.Fprint(w, string(respJson)) + fmt.Printf(time.Now().Format(time.RFC3339)+" Scan result for: %v, %v\n", part.FileName(), s) + } + fmt.Printf(time.Now().Format(time.RFC3339) + " Finished scanning: " + part.FileName() + "\n") default: w.WriteHeader(http.StatusMethodNotAllowed) }