Use context.Context for timeout

This commit is contained in:
Locria Cyber 2023-10-03 10:56:35 +00:00
parent 61d065c35d
commit 0cb3224558
Signed by: iacore
GPG key ID: F8C16E5157A63006
5 changed files with 28 additions and 25 deletions

View file

@ -1,11 +1,13 @@
package main
import (
"context"
"io/fs"
"log"
"net"
"net/http"
"os"
"time"
"embed"
@ -54,7 +56,10 @@ func rootPage(w http.ResponseWriter, r *http.Request) {
}
// get service status'
data := scanner.TestMultipleServices(config.Service)
timeout := time.Second
ctx, cancel := context.WithTimeout(context.Background(), timeout)
data := scanner.TestMultipleServices(ctx, config.Service)
cancel()
// don't uncomment this in production. verbose
// log.Printf("%v", data)

View file

@ -2,17 +2,20 @@ package scanner
import "net/http"
func Request(url string) (bool, string) {
req, err := http.Get(url)
func CheckHTTP200(ctx Context, url string) (bool, string) {
req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
if err != nil {
return false, err.Error()
}
resp, err := http.DefaultClient.Do(req)
if err != nil {
// `req` could be `nil` so we need this check
return false, err.Error()
}
// if code is 200-299
if req.StatusCode >= 200 && req.StatusCode <= 299 {
return true, req.Status
if resp.StatusCode >= 200 && resp.StatusCode <= 299 {
return true, resp.Status
}
return false, req.Status
return false, resp.Status
}

View file

@ -2,15 +2,9 @@ package scanner
import (
"github.com/prometheus-community/pro-bing"
"time"
"context"
)
func CheckICMPUp(host string) (bool, string) {
timeout := time.Second
ctx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()
func CheckICMPUp(ctx Context, host string) (bool, string) {
pinger, err := probing.NewPinger(host)
if err != nil {
return false, err.Error()

View file

@ -1,11 +1,13 @@
package scanner
import (
"context"
"log"
"net/url"
"sync"
)
type Context = context.Context
type ServiceConfig struct {
Description string
@ -28,7 +30,7 @@ type ServiceStatus struct {
Status string
}
func TestService(service ServiceConfig) ServiceStatus {
func TestService(service ServiceConfig, ctx context.Context) ServiceStatus {
var serviceResult bool
var status string
@ -36,20 +38,20 @@ func TestService(service ServiceConfig) ServiceStatus {
switch url.Scheme {
case "http", "https":
{
serviceResult, status = Request(url.String())
serviceResult, status = CheckHTTP200(ctx, url.String())
break
}
case "tcp":
{
// go includes port in host
serviceResult, status = CheckTCPOpen(url.Host)
serviceResult, status = CheckTCPOpen(ctx, url.Host)
break
}
case "icmp":
{
// go includes port in host
serviceResult, status = CheckICMPUp(url.Host)
serviceResult, status = CheckICMPUp(ctx, url.Host)
break
}
default:
@ -66,7 +68,7 @@ func TestService(service ServiceConfig) ServiceStatus {
}
}
func TestMultipleServices(services []ServiceConfig) []ServiceStatus {
func TestMultipleServices(ctx Context, services []ServiceConfig) []ServiceStatus {
var wg sync.WaitGroup
var data []ServiceStatus = make([]ServiceStatus, len(services))
@ -75,7 +77,7 @@ func TestMultipleServices(services []ServiceConfig) []ServiceStatus {
wg.Add(1)
go func(i int, config ServiceConfig) {
defer wg.Done()
data[i] = TestService(config)
data[i] = TestService(config, ctx)
}(i, v)
}
wg.Wait()

View file

@ -2,15 +2,14 @@ package scanner
import (
"net"
"time"
)
func CheckTCPOpen(host string) (bool, string) {
func CheckTCPOpen(ctx Context, host string) (bool, string) {
// Just check if the connection is successful
timeout := time.Second
conn, err := net.DialTimeout("tcp", host, timeout)
d := net.Dialer{}
conn, err := d.DialContext(ctx, "tcp", host)
if err != nil || conn == nil {
if err != nil {
// If timeout or anything else
return false, err.Error()
}