package main

import (
	"bufio"
	"bytes"
	"crypto/sha1"
	"encoding/csv"
	"encoding/json"
	"flag"
	"fmt"
	"io"
	"log"
	"math/rand"
	"net"
	"net/http"
	"os"
	"strings"
	"sync/atomic"
	"time"

	"github.com/BurntSushi/toml"
	"github.com/google/gopacket"
	"github.com/google/gopacket/layers"
	"github.com/google/gopacket/pcap"
	"github.com/nats-io/nats"
	"github.com/nats-io/nats/encoders/protobuf"

	"git.scraperwall.com/scw/ajp13"
	"git.scraperwall.com/scw/data"
	"git.scraperwall.com/scw/ip"
)

// Command-line flags and process-wide state shared by the capture, replay and
// publishing code.
var (
	// Command-line flags. loadConfig copies each of them into config and then
	// decodes the optional TOML file named by -config on top, so settings from
	// that file take precedence over the flags.
	doLiveCapture         = flag.Bool("live", false, "Capture data in real time from a given interface")
	iface                 = flag.String("interface", "eth0", "Interface to get packets from")
	snapshotLen           = flag.Int("snapshot-len", 8192, "Snapshot Length in Bytes")
	filter                = flag.String("filter", "tcp", "PCAP filter expression")
	promiscuous           = flag.Bool("promiscuous", false, "Switch interface into promiscuous mode?")
	natsURL               = flag.String("nats-url", "nats://127.0.0.1:4222", "The URL of the NATS server")
	natsUser              = flag.String("nats-user", "", "The user for NATS authentication")
	natsPassword          = flag.String("nats-password", "", "The password for NATS authentication")
	natsQueue             = flag.String("nats-queue", "requests", "The NATS queue name")
	natsCA                = flag.String("nats-ca", "", "CA chain for NATS TLS")
	reconnectToNatsAfter  = flag.Duration("reconnect-to-nats-after", 0, "reconnect to nats after this time periodically")
	resetLiveCapAfter     = flag.Duration("reset-live-cap-after", 613*time.Second, "reset the live capture setup after this amount of time")
	sleepFor              = flag.Duration("sleep", 0, "Sleep this long between sending data (only when replaying a file)")
	requestsFile          = flag.String("requests", "", "CSV file containing requests (IP and URL)")
	protocol              = flag.String("protocol", "http", "which protocol to parse: http or ajp13")
	useXForwardedAsSource = flag.Bool("use-x-forwarded", false, "Use the IP address in X-Forwarded-For as source")
	useVhostAsSource      = flag.Bool("use-vhost-as-source", false, "Use the Vhost as source")
	trace                 = flag.Bool("trace", false, "Trace the packet capturing")
	tailPoll              = flag.Bool("tail-poll", false, "use file polling to detect file changes when tailing logrfiles")
	apacheLog             = flag.String("apache-log", "", "Parse an Apache Log file")
	apacheReplay          = flag.String("apache-replay", "", "Apache log file to replay into the system")
	nginxLog              = flag.String("nginx-log", "", "Nginx log file to tail")
	nginxFormat           = flag.String("nginx-format", "", "The nginx log file format")
	nginxReplay           = flag.String("nginx-replay", "", "Replay this nginx logfile")
	hostName              = flag.String("hostname", "", "Override the captured hostname with this one")
	accessWatchKey        = flag.String("access-watch-key", "", "access.watch API key")
	configFile            = flag.String("config", "", "The location of the TOML config file")

	// -quiet defaults to true: the per-second request rate and the config dump
	// are only printed when -quiet=false is given (or Quiet=false in the file).
	beQuiet   = flag.Bool("quiet", true, "Be quiet")
	doVersion = flag.Bool("version", false, "Show version information")

	// Runtime state:
	//   - natsEC / natsJSONEC: protobuf- and JSON-encoded wrappers that
	//     connectToNATS builds on one shared NATS connection.
	//   - natsErrorChan: NATS publish errors, consumed by natsWatchdog.
	//   - natsIsAvailable: true once connectToNATS succeeded; cleared when a
	//     publish reports a closed connection.
	//   - count: requests handled since the last per-second report in main.
	//   - timeout: read timeout passed to pcap.OpenLive. NOTE(review): a
	//     negative value presumably means "block until packets arrive" in
	//     gopacket — confirm against the pcap package docs.
	//   - ipPriv: private-address classifier, initialized by liveCapture and
	//     used for X-Forwarded-For handling in processPacket.
	//   - config: the effective configuration assembled by loadConfig.
	natsEC          *nats.EncodedConn
	natsJSONEC      *nats.EncodedConn
	natsErrorChan   chan error
	natsIsAvailable bool
	count           uint64
	timeout         = -1 * time.Second
	ipPriv          *ip.IP
	config          Config

	// Version contains the program Version, e.g. 1.0.1
	Version string

	// BuildDate contains the date and time at which the program was compiled
	BuildDate string
)

// Config contains the program configuration.
//
// loadConfig fills every field from the corresponding command-line flag and
// then, if -config names an existing file, decodes that TOML file on top of
// it, so values from the file win. duration fields are decoded from strings
// accepted by time.ParseDuration (e.g. "30s"), see duration.UnmarshalText.
//
// Field groups:
//   - Live, Interface, SnapshotLen, Filter, Promiscuous, ResetLiveCaptureAfter:
//     live pcap capture (see liveCapture and liveCap.SetupCap).
//   - NatsURL, NatsQueue, NatsUser, NatsPassword, NatsCA, ReconnectToNatsAfter:
//     the NATS connection; requests are published to the NatsQueue subject.
//   - AccessWatchKey: when set, publishRequest sends requests to access.watch
//     instead of NATS.
//   - RequestsFile, SleepFor: space-separated "IP URL" file replayed by
//     replayFile, with SleepFor as the minimum time spent per request.
//   - Protocol ("http" or "ajp13"), UseXForwardedAsSource, Trace: how captured
//     payloads are parsed, whether X-Forwarded-For may set the request source,
//     and extra trace logging.
//   - ApacheLog, ApacheReplay, NginxLog, NginxLogFormat, NginxReplay: log files
//     to tail or replay; presumably handled by code outside this file.
//   - Quiet: suppresses informational output such as the per-second request
//     rate and the configuration dump.
//   - UseVhostAsSource, TailPoll, HostName: NOTE(review): not read by the code
//     in this file — presumably used elsewhere in the package; confirm.
type Config struct {
	Live                  bool
	Interface             string
	SnapshotLen           int
	Filter                string
	Promiscuous           bool
	NatsURL               string
	NatsQueue             string
	NatsUser              string
	NatsPassword          string
	NatsCA                string
	SleepFor              duration
	RequestsFile          string
	UseXForwardedAsSource bool
	UseVhostAsSource      bool
	Quiet                 bool
	Protocol              string
	Trace                 bool
	TailPoll              bool
	ApacheLog             string
	ApacheReplay          string
	NginxLog              string
	NginxLogFormat        string
	NginxReplay           string
	HostName              string
	AccessWatchKey        string
	ReconnectToNatsAfter  duration
	ResetLiveCaptureAfter duration
}

// duration is a time.Duration that can be decoded from text, such as a TOML
// string like "1m30s".
type duration struct {
	time.Duration
}

// UnmarshalText parses text with time.ParseDuration and stores the result.
// On a parse error the stored duration is whatever ParseDuration returned
// alongside the error, and that error is passed back to the caller.
func (d *duration) UnmarshalText(text []byte) error {
	parsed, parseErr := time.ParseDuration(string(text))
	d.Duration = parsed
	return parseErr
}

// print writes the effective configuration to stdout, one "Name: value" line
// per setting.
//
// The NATS password and the access.watch API key are masked (printed as
// "********" when set, empty otherwise). This way the dump can end up in logs
// without leaking credentials.
func (c Config) print() {
	secret := func(s string) string {
		if s == "" {
			return ""
		}
		return "********"
	}

	fmt.Printf("Live:                  %t\n", c.Live)
	fmt.Printf("Interface:             %s\n", c.Interface)
	fmt.Printf("SnapshotLen:           %d\n", c.SnapshotLen)
	fmt.Printf("Filter:                %s\n", c.Filter)
	fmt.Printf("Promiscuous:           %t\n", c.Promiscuous)
	fmt.Printf("NatsURL:               %s\n", c.NatsURL)
	fmt.Printf("NatsQueue:             %s\n", c.NatsQueue)
	fmt.Printf("NatsUser:              %s\n", c.NatsUser)
	fmt.Printf("NatsPassword:          %s\n", secret(c.NatsPassword))
	fmt.Printf("NatsCA:                %s\n", c.NatsCA)
	fmt.Printf("ReconnectToNatsAfter:  %s\n", c.ReconnectToNatsAfter.String())
	fmt.Printf("SleepFor:              %s\n", c.SleepFor.String())
	fmt.Printf("RequestsFile:          %s\n", c.RequestsFile)
	fmt.Printf("TailPoll:              %t\n", c.TailPoll)
	fmt.Printf("Apache Log:            %s\n", c.ApacheLog)
	fmt.Printf("Apache Replay:         %s\n", c.ApacheReplay)
	fmt.Printf("Nginx Log:             %s\n", c.NginxLog)
	fmt.Printf("Nginx Log Format:      %s\n", c.NginxLogFormat)
	fmt.Printf("NginxReplay:           %s\n", c.NginxReplay)
	fmt.Printf("HostName:              %s\n", c.HostName)
	fmt.Printf("AccessWatchKey:        %s\n", secret(c.AccessWatchKey))
	fmt.Printf("UseXForwardedAsSource: %t\n", c.UseXForwardedAsSource)
	fmt.Printf("UseVhostAsSource:      %t\n", c.UseVhostAsSource)
	fmt.Printf("Protocol:              %s\n", c.Protocol)
	fmt.Printf("Reset Live Cap After:  %s\n", c.ResetLiveCaptureAfter.String())
	fmt.Printf("Quiet:                 %t\n", c.Quiet)
	fmt.Printf("Trace:                 %t\n", c.Trace)
}

// init registers the protobuf encoder with the NATS client library and parses
// the command-line flags, so both are in place before main runs.
func init() {
	nats.RegisterEncoder(protobuf.PROTOBUF_ENCODER, &protobuf.ProtobufEncoder{})
	flag.Parse()
}

// main loads the configuration and connects to NATS. The connection is
// mandatory unless an access.watch key is configured.
//
// It then runs exactly one mode, chosen in this order: request-file replay,
// Apache log replay, nginx log replay, Apache log tailing, live packet
// capture, nginx log tailing.
func main() {
	if *doVersion {
		version()
		os.Exit(0)
	}

	loadConfig()

	// Output how many requests per second were sent. The packet and replay
	// code increments the counter concurrently. Reading and resetting it with
	// one atomic swap avoids racing with them and losing increments that
	// land between a separate read and reset.
	if !config.Quiet {
		go func(c *uint64) {
			for {
				fmt.Printf("%d requests per second\n", atomic.SwapUint64(c, 0))
				time.Sleep(time.Second)
			}
		}(&count)
	}

	// NATS
	//
	if config.NatsURL == "" && config.AccessWatchKey == "" {
		log.Fatal("No NATS URL specified (-nats-url)!")
	}

	natsIsAvailable = false
	natsErrorChan = make(chan error, 1)

	// Without an access.watch key NATS is the only sink, so a failed connect
	// is fatal.
	if err := connectToNATS(); err != nil && config.AccessWatchKey == "" {
		log.Fatal(err)
	}

	// Reconnect to NATS periodically by closing the current connection. The
	// next publish should then report nats.ErrConnectionClosed, which
	// natsWatchdog answers by reconnecting. Intervals of one minute or less
	// disable this. The effective config value is used rather than the raw
	// flag, so the TOML file can set it as well.
	if interval := config.ReconnectToNatsAfter.Duration; interval > time.Minute {
		go func(interval time.Duration) {
			for range time.Tick(interval) {
				if config.Trace {
					log.Printf("reconnecting to NATS")
				}
				// Either connection may be nil if the initial connect failed.
				// That is tolerated when an access.watch key is set.
				if natsEC != nil {
					natsEC.Conn.Close()
				}
				if natsJSONEC != nil {
					natsJSONEC.Conn.Close()
				}
			}
		}(interval)
	}

	go natsWatchdog(natsErrorChan)

	// What should I do?
	switch {
	case config.RequestsFile != "":
		replayFile()
	case config.ApacheReplay != "":
		apacheLogReplay(config.ApacheReplay)
	case config.NginxReplay != "":
		nginxLogReplay(config.NginxReplay, config.NginxLogFormat)
	case config.ApacheLog != "":
		apacheLogCapture(config.ApacheLog)
	case config.Live:
		fmt.Printf("live capture (%s, %s) to %s\n", config.Interface, config.Filter, config.NatsURL)
		liveCapture()
	case config.NginxLog != "" && config.NginxLogFormat != "":
		nginxLogCapture(config.NginxLog, config.NginxLogFormat)
	}
}

// natsWatchdog consumes NATS errors from closedChan until the channel is
// closed. An error is logged only when it differs from the previous one.
//
// On nats.ErrConnectionClosed it reconnects via connectToNATS, retrying once
// per second until that succeeds. Closed-connection errors that arrive after
// a successful reconnect are recognised as stale and skipped. Otherwise every
// queued error would open another connection and leak the previous live one.
func natsWatchdog(closedChan chan error) {
	var lastError error

	for err := range closedChan {
		// Avoid repeating the same message for every failed publish.
		if err != lastError {
			lastError = err
			log.Println(err)
		}

		if err != nats.ErrConnectionClosed {
			continue
		}

		// Several publishers may report the same closed connection. If an
		// earlier report already led to a working connection, this one is stale.
		if natsEC != nil && !natsEC.Conn.IsClosed() {
			continue
		}

		for {
			// Log the effective URL; the TOML config may override -nats-url.
			log.Printf("Reconnecting to NATS at %s\n", config.NatsURL)
			if err := connectToNATS(); err == nil {
				break
			}
			time.Sleep(1 * time.Second)
		}
	}
}

// connectToNATS dials config.NatsURL. On success it replaces the global
// natsEC (protobuf encoder) and natsJSONEC (JSON encoder), which share one
// underlying connection, and sets natsIsAvailable.
//
// Credentials are sent only when both config.NatsUser and config.NatsPassword
// are set. config.NatsCA, when set, is used as the TLS root CA whether or not
// credentials are configured.
//
// On error the globals are left unchanged and any connection opened by this
// call is closed.
func connectToNATS() error {
	var opts []nats.Option
	if config.NatsUser != "" && config.NatsPassword != "" {
		opts = append(opts, nats.UserInfo(config.NatsUser, config.NatsPassword))
	}
	if config.NatsCA != "" {
		opts = append(opts, nats.RootCAs(config.NatsCA))
	}

	natsConn, err := nats.Connect(config.NatsURL, opts...)
	if err != nil {
		return err
	}

	protoEC, err := nats.NewEncodedConn(natsConn, protobuf.PROTOBUF_ENCODER)
	if err != nil {
		natsConn.Close()
		return fmt.Errorf("Encoded Connection: %v", err)
	}

	jsonEC, err := nats.NewEncodedConn(natsConn, nats.JSON_ENCODER)
	if err != nil {
		natsConn.Close()
		return fmt.Errorf("Encoded Connection: %v", err)
	}

	// Publish both wrappers only once both have been created.
	natsEC, natsJSONEC = protoEC, jsonEC
	natsIsAvailable = true
	return nil
}

// liveCap holds the state of a live pcap capture. It contains:
//   - the settings recorded by newLiveCap: filter (BPF filter expression),
//     device (interface name), promisc (promiscuous mode) and snapshotLen
//     (snapshot length in bytes);
//   - the currently open pcap handle;
//   - packetChan, the channel of packets read from that handle.
type liveCap struct {
	filter      string
	device      string
	promisc     bool
	snapshotLen int
	handle      *pcap.Handle
	packetChan  chan gopacket.Packet
}

// newLiveCap records the capture settings in a fresh liveCap and opens the
// capture through SetupCap. When the setup fails, the error is returned
// together with a nil *liveCap.
func newLiveCap(device string, filter string, snapshotLen int, promisc bool) (*liveCap, error) {
	capture := new(liveCap)
	capture.device, capture.filter = device, filter
	capture.snapshotLen, capture.promisc = snapshotLen, promisc

	if setupErr := capture.SetupCap(); setupErr != nil {
		return nil, setupErr
	}
	return capture, nil
}

// SetupCap (re)opens the pcap capture described by lc. It opens lc.device
// with lc.snapshotLen, lc.promisc and the package-level read timeout, applies
// the BPF filter lc.filter, and stores the packet channel of a new packet
// source in lc.packetChan.
//
// The new handle is fully set up before the previous one is closed, so a
// failed reset leaves the running capture untouched. On error the new handle
// is closed and the pcap error is returned.
func (lc *liveCap) SetupCap() error {
	if !config.Quiet {
		log.Printf("reading incoming HTTP requests on %s %s\n", lc.device, lc.filter)
	}

	// PCAP setup
	//
	handle, err := pcap.OpenLive(lc.device, int32(lc.snapshotLen), lc.promisc, timeout)
	if err != nil {
		return err
	}

	if err := handle.SetBPFFilter(lc.filter); err != nil {
		handle.Close()
		return err
	}

	// Swap in the new capture only now that it is ready.
	if lc.handle != nil {
		lc.handle.Close()
	}
	lc.handle = handle
	lc.packetChan = gopacket.NewPacketSource(handle, handle.LinkType()).Packets()

	return nil
}

// liveCapture captures packets on config.Interface, filtered by
// config.Filter, and hands each packet to processPacket in its own goroutine.
// It never returns. It exits the program via log.Fatal if the capture cannot
// be opened initially, or cannot be restarted after its packet source shut
// down.
//
// The capture is re-created every config.ResetLiveCaptureAfter. A
// non-positive interval disables this: time.Tick then returns nil, and
// receiving from a nil channel blocks forever.
func liveCapture() {
	ipPriv = ip.NewIP()

	livecap, err := newLiveCap(config.Interface, config.Filter, config.SnapshotLen, config.Promiscuous)
	if err != nil {
		log.Fatal(err)
	}

	resetChan := time.Tick(config.ResetLiveCaptureAfter.Duration)

	for {
		select {
		case <-resetChan:
			if err := livecap.SetupCap(); err != nil {
				log.Printf("resetting live capture: %v", err)
			}
		case p, ok := <-livecap.packetChan:
			if !ok {
				// The packet source shut down. Receiving from the closed
				// channel would yield nil packets in a busy loop, and
				// processPacket panics on a nil packet. Re-create the
				// capture right away instead.
				if err := livecap.SetupCap(); err != nil {
					log.Fatalf("live capture closed and could not be restarted: %v", err)
				}
				continue
			}
			go processPacket(p)
		}
	}
}

// accessWatchClient is shared by all writeLogToWatch calls so that
// connections to the access.watch API are reused. The timeout keeps an
// unresponsive endpoint from blocking request processing indefinitely.
var accessWatchClient = &http.Client{Timeout: 10 * time.Second}

// writeLogToWatch posts r as a JSON log entry to the access.watch API
// (https://log.access.watch/1.1/log), authenticated with
// config.AccessWatchKey.
//
// Only non-empty request headers are included. The scheme is always reported
// as "https" and the response status as "200". The JSON payload is echoed to
// stdout when config.Trace is set. Failures are logged and otherwise ignored.
func writeLogToWatch(r *data.Request) {
	headers := make(map[string]string)
	for _, h := range []struct{ name, value string }{
		{"Accept-Encoding", r.AcceptEncoding},
		{"Accept", r.Accept},
		{"Accept-Language", r.AcceptLanguage},
		{"Cookie", r.Cookie},
		{"Host", r.Host},
		{"Referer", r.Referer},
		{"User-Agent", r.UserAgent},
		{"Via", r.Via},
		{"X-Forwarded-For", r.XForwardedFor},
		{"X-Requested-With", r.XRequestedWith},
	} {
		if h.value != "" {
			headers[h.name] = h.value
		}
	}

	payload := map[string]interface{}{
		"request": map[string]interface{}{
			"time":     time.Unix(0, r.CreatedAt),
			"address":  r.Source,
			"protocol": r.Protocol,
			"scheme":   "https",
			"method":   r.Method,
			"url":      r.Url,
			"headers":  headers,
		},
		"response": map[string]interface{}{"status": "200"},
	}

	body, err := json.Marshal(payload)
	if err != nil {
		log.Printf("access.watch: encoding log entry: %v", err)
		return
	}

	// The payload contains full request headers including cookies, so only
	// echo it when tracing.
	if config.Trace {
		fmt.Println(string(body))
	}

	req, err := http.NewRequest("POST", "https://log.access.watch/1.1/log", bytes.NewReader(body))
	if err != nil {
		log.Println(err)
		return
	}
	req.Header.Add("Api-Key", config.AccessWatchKey)
	req.Header.Add("Accept", "application/json")
	req.Header.Add("Content-Type", "application/json")

	resp, err := accessWatchClient.Do(req)
	if err != nil {
		// resp is nil here; it must not be touched.
		log.Println(err)
		return
	}
	defer resp.Body.Close()

	if resp.StatusCode < 200 || resp.StatusCode > 299 {
		log.Printf("access.watch: unexpected response status %s", resp.Status)
	}
}

// publishRequest forwards request to access.watch when config.AccessWatchKey
// is set. Otherwise it publishes the request to the NATS subject queue,
// falling back to config.NatsQueue when queue is empty.
//
// While NATS is unavailable the request is dropped and, for roughly 1% of
// calls, a message is logged. Publish errors are handed to natsWatchdog via
// natsErrorChan. A closed connection is always delivered because it triggers
// the reconnect. Other errors are dropped when the watchdog is busy, so
// callers do not block.
func publishRequest(queue string, request *data.Request) {
	if config.AccessWatchKey != "" {
		writeLogToWatch(request)
		return
	}

	if !natsIsAvailable {
		if rand.Intn(100) == 0 {
			log.Println("nats connection is not available")
		}
		return
	}

	if queue == "" {
		queue = config.NatsQueue
	}

	err := natsEC.Publish(queue, request)
	if err == nil {
		return
	}

	if err == nats.ErrConnectionClosed {
		// Mark NATS unavailable before the (possibly blocking) hand-off, so
		// concurrent callers stop publishing and queueing errors right away.
		natsIsAvailable = false
		natsErrorChan <- err
		return
	}

	// natsWatchdog only logs other errors; never block a packet goroutine for
	// that.
	select {
	case natsErrorChan <- err:
	default:
	}
}

// processPacket receives a raw packet from pcap, builds a Request item from it
// and sends it to the queue.
//
// Processing steps:
//   - Packets without a TCP layer or without an application payload are
//     ignored.
//   - Every packet with a payload is counted for the per-second report.
//     Payloads shorter than 50 bytes are then dropped.
//   - The payload is parsed according to config.Protocol ("http" or "ajp13").
//   - The request source may then be replaced by a public address from
//     X-Forwarded-For (when config.UseXForwardedAsSource is set) or by
//     X-Real-IP.
//   - Finally the request is published.
func processPacket(packet gopacket.Packet) {
	var ipSrc, ipDst string

	// Prefer the IPv4 addresses and fall back to IPv6.
	if ip4Layer := packet.Layer(layers.LayerTypeIPv4); ip4Layer != nil {
		ip4 := ip4Layer.(*layers.IPv4)
		ipSrc, ipDst = ip4.SrcIP.String(), ip4.DstIP.String()
	} else if ip6Layer := packet.Layer(layers.LayerTypeIPv6); ip6Layer != nil {
		ip6 := ip6Layer.(*layers.IPv6)
		ipSrc, ipDst = ip6.SrcIP.String(), ip6.DstIP.String()
	}

	// TCP
	tcpLayer := packet.Layer(layers.LayerTypeTCP)
	if tcpLayer == nil {
		return
	}
	tcp, ok := tcpLayer.(*layers.TCP)
	if !ok {
		return
	}

	applicationLayer := packet.ApplicationLayer()
	if applicationLayer == nil {
		return
	}
	payload := applicationLayer.Payload()

	// count is read and reset concurrently by the reporting goroutine in main.
	atomic.AddUint64(&count, 1)

	if len(payload) < 50 {
		log.Println("application layer too small!")
		return
	}

	request := data.Request{
		IpSrc:     ipSrc,
		IpDst:     ipDst,
		PortSrc:   uint32(tcp.SrcPort),
		PortDst:   uint32(tcp.DstPort),
		TcpSeq:    uint32(tcp.Seq),
		CreatedAt: packet.Metadata().CaptureInfo.Timestamp.UnixNano(),
	}

	var err error
	switch config.Protocol {
	case "http":
		err = processHTTP(&request, payload)
	case "ajp13":
		err = processAJP13(&request, payload)
	}
	if err != nil {
		log.Println(err)
		return
	}

	if config.UseXForwardedAsSource && request.XForwardedFor != "" {
		// Walk the list from the right (the entries appended by the proxies
		// nearest to us) and take the first public address. A single value is
		// handled the same way, so an unparsable or private value is never
		// used as the source.
		hops := strings.Split(request.XForwardedFor, ",")
		for i := len(hops) - 1; i >= 0; i-- {
			candidate := strings.TrimSpace(hops[i])
			if addr := net.ParseIP(candidate); addr != nil && !ipPriv.IsPrivate(addr) {
				request.Source = candidate
				break
			}
		}
	}

	// Fall back to X-Real-IP when nothing replaced the packet's source address.
	if request.Source == request.IpSrc && request.XRealIP != "" {
		request.Source = request.XRealIP
	}

	if config.Trace {
		log.Printf("[%s] %s\n", request.Source, request.Url)
	}

	publishRequest(config.NatsQueue, &request)
}

// processAJP13 parses appData with ajp13.Parse and copies the URI, method,
// server name, protocol version, remote address and selected headers into
// request.
//
// Both request.Origin and request.Source are set to the remote address. An
// X-Forwarded-Host header, if present, replaces the host taken from the
// server name. Headers that are absent leave their fields untouched.
func processAJP13(request *data.Request, appData []byte) error {
	msg, parseErr := ajp13.Parse(appData)
	if parseErr != nil {
		return fmt.Errorf("Failed to parse AJP13 request: %s", parseErr)
	}

	remote := msg.RemoteAddr.String()
	request.Url = msg.URI
	request.Method = msg.Method()
	request.Host = msg.Server
	request.Protocol = msg.Version
	request.Origin = remote
	request.Source = remote

	// Header name -> destination field. X-Forwarded-Host comes last so that
	// it overrides the server name assigned above.
	mapping := []struct {
		header string
		dst    *string
	}{
		{"Referer", &request.Referer},
		{"Connection", &request.Connection},
		{"X-Forwarded-For", &request.XForwardedFor},
		{"X-Real-IP", &request.XRealIP},
		{"X-Requested-With", &request.XRequestedWith},
		{"Accept-Encoding", &request.AcceptEncoding},
		{"Accept-Language", &request.AcceptLanguage},
		{"User-Agent", &request.UserAgent},
		{"Accept", &request.Accept},
		{"Cookie", &request.Cookie},
		{"X-Forwarded-Host", &request.Host},
	}
	for _, m := range mapping {
		if value, found := msg.Header(m.header); found {
			*m.dst = value
		}
	}

	return nil
}

// processHTTP parses appData as an HTTP request head with http.ReadRequest.
// It copies the request line, host and selected headers into request.
//
// A header that is absent (or has no values) leaves its field unchanged.
// CloudFlare's True-Client-Ip, when present, overrides X-Forwarded-For.
// request.Origin is set to the Host and request.Source to request.IpSrc.
// An error is returned when appData cannot be parsed.
func processHTTP(request *data.Request, appData []byte) error {
	// Reading from the byte slice directly avoids copying every captured
	// payload into a string first.
	req, err := http.ReadRequest(bufio.NewReader(bytes.NewReader(appData)))
	if err != nil {
		return fmt.Errorf("Failed to parse HTTP header: %s", err)
	}

	request.Url = req.URL.String()
	request.Method = req.Method
	request.Referer = req.Referer()
	request.Host = req.Host
	request.Protocol = req.Proto
	request.Origin = request.Host

	// Canonical header key (as stored by http.ReadRequest) -> destination.
	// Order matters: True-Client-Ip comes after X-Forwarded-For because the
	// latter is tainted by CloudFlare.
	for _, h := range []struct {
		key string
		dst *string
	}{
		{"Connection", &request.Connection},
		{"X-Forwarded-For", &request.XForwardedFor},
		{"True-Client-Ip", &request.XForwardedFor},
		{"X-Real-Ip", &request.XRealIP},
		{"X-Requested-With", &request.XRequestedWith},
		{"Accept-Encoding", &request.AcceptEncoding},
		{"Accept-Language", &request.AcceptLanguage},
		{"User-Agent", &request.UserAgent},
		{"Accept", &request.Accept},
		{"Cookie", &request.Cookie},
	} {
		// One lookup, and no index panic on an empty value slice.
		if values := req.Header[h.key]; len(values) > 0 {
			*h.dst = values[0]
		}
	}

	request.Source = request.IpSrc

	return nil
}

// replayFile takes a file containing a list of space-separated requests
// (SourceIP Url) and queues them, e.g.
//
//	157.55.39.229 /gross-gerau/12012260-beate-anstatt
//	103.232.100.98 /weinsheim-eifel/13729444-plus-warenhandelsges-mbh
//
// The file is replayed in an endless loop, reopening it after each pass.
//
// Every request is published via publishRequest. For URLs without a "." a
// fingerprint is additionally published to "fingerprints_scw" on the JSON NATS
// connection. This happens always for 50.31.* sources and for about half of
// the other requests.
//
// When config.SleepFor is positive, each request takes at least that long.
// Unreadable records and records with fewer than two fields are logged and
// skipped. The process exits if the file cannot be opened.
func replayFile() {
	var req data.Request

	rand.Seed(time.Now().UnixNano())

	for {
		fh, err := os.Open(config.RequestsFile)
		if err != nil {
			log.Fatalf("Failed to open request file '%s': %s", config.RequestsFile, err)
		}

		c := csv.NewReader(fh)
		c.Comma = ' '

		for {
			startTs := time.Now()

			r, err := c.Read()
			if err == io.EOF {
				break
			}
			if err != nil {
				log.Println(err)
				continue
			}
			if len(r) < 2 {
				log.Printf("%s: skipping record with fewer than 2 fields: %q", config.RequestsFile, r)
				continue
			}

			req.IpSrc = r[0]
			req.Source = r[0]
			req.Url = r[1]
			req.UserAgent = "Munch/1.0"
			req.Host = "demo.scraperwall.com"
			req.CreatedAt = time.Now().UnixNano()

			publishRequest(config.NatsQueue, &req)

			// natsJSONEC is nil when NATS could not be reached at startup,
			// which is tolerated when an access.watch key is set.
			if !strings.Contains(r[1], ".") && natsJSONEC != nil {
				hash := sha1.New()
				io.WriteString(hash, r[0])
				fp := data.Fingerprint{
					ClientID:    "scw",
					Fingerprint: fmt.Sprintf("%x", hash.Sum(nil)),
					Remote:      r[0],
					Url:         r[1],
					Source:      r[0],
					CreatedAt:   time.Now(),
				}

				// Fingerprint publishing is best-effort; errors are ignored.
				switch {
				case strings.HasPrefix(r[0], "50.31."):
					fp.Fingerprint = "a1f2c2ee560ce6580d66d451a9c8dfbf"
					natsJSONEC.Publish("fingerprints_scw", fp)
				case rand.Intn(10) < 5:
					natsJSONEC.Publish("fingerprints_scw", fp)
				}
			}

			// count is read and reset concurrently by the reporting goroutine
			// in main.
			atomic.AddUint64(&count, 1)

			if config.SleepFor.Duration > 0 {
				if remaining := config.SleepFor.Duration - time.Since(startTs); remaining > 0 {
					time.Sleep(remaining)
				}
			}
		}

		// Close before the next pass reopens the file, so the endless replay
		// does not leak a file descriptor per pass.
		fh.Close()
	}
}

// loadConfig builds the global config.
//
// Every field is first taken from the corresponding command-line flag. If
// -config is set and the file exists, its TOML content is decoded on top, so
// file values win. Problems with the file are logged and do not stop the
// program. When a config file was found, the resulting configuration is
// printed unless Quiet is set.
func loadConfig() {
	config = Config{
		Live:                  *doLiveCapture,
		Interface:             *iface,
		SnapshotLen:           *snapshotLen,
		Filter:                *filter,
		Promiscuous:           *promiscuous,
		NatsURL:               *natsURL,
		NatsQueue:             *natsQueue,
		NatsUser:              *natsUser,
		NatsPassword:          *natsPassword,
		NatsCA:                *natsCA,
		SleepFor:              duration{Duration: *sleepFor},
		RequestsFile:          *requestsFile,
		UseXForwardedAsSource: *useXForwardedAsSource,
		UseVhostAsSource:      *useVhostAsSource,
		Quiet:                 *beQuiet,
		Protocol:              *protocol,
		Trace:                 *trace,
		TailPoll:              *tailPoll,
		ApacheLog:             *apacheLog,
		ApacheReplay:          *apacheReplay,
		NginxLog:              *nginxLog,
		NginxLogFormat:        *nginxFormat,
		NginxReplay:           *nginxReplay,
		HostName:              *hostName,
		AccessWatchKey:        *accessWatchKey,
		ReconnectToNatsAfter:  duration{Duration: *reconnectToNatsAfter},
		ResetLiveCaptureAfter: duration{Duration: *resetLiveCapAfter},
	}

	path := *configFile
	if path == "" {
		return
	}

	if _, statErr := os.Stat(path); statErr != nil {
		log.Printf("%s: %s\n", path, statErr)
		return
	}

	if _, decodeErr := toml.DecodeFile(path, &config); decodeErr != nil {
		log.Printf("%s: %s\n", path, decodeErr)
	}

	if !config.Quiet {
		config.print()
	}
}

// version prints the program name followed by Version and BuildDate to
// stdout.
func version() {
	fmt.Println("munchclient " + Version + ", built on " + BuildDate)
}