// SPDX-FileCopyrightText: Copyright The Lima Authors
// SPDX-License-Identifier: Apache-2.0

package portfwd

import (
	"context"
	"fmt"
	"net"
	"path/filepath"
	"strconv"

	"github.com/sirupsen/logrus"
)

// Listen opens a stream listener for hostAddress.
//
// An absolute filesystem path is treated as a Unix domain socket: the path is
// first passed to prepareUnixSocket and then listened on with a zero-value
// net.ListenConfig.
//
// Any other value is treated as a "host:port" TCP address. When the host is
// IPv4loopback1 and the port is below 1024, the listener is bound to 0.0.0.0
// on the same port instead and wrapped in a pseudoLoopbackListener so that
// only loopback peers are handed to the caller. Every other TCP address is
// listened on exactly as given.
//
// Listen failures are reported through logListenError before being returned.
func Listen(ctx context.Context, listenConfig net.ListenConfig, hostAddress string) (net.Listener, error) {
	if filepath.IsAbs(hostAddress) {
		if err := prepareUnixSocket(hostAddress); err != nil {
			return nil, err
		}
		// NOTE(review): a zero-value ListenConfig is used here instead of
		// listenConfig, so listenConfig's options do not apply to Unix
		// sockets. This looks intentional; confirm before changing.
		var unixConfig net.ListenConfig
		ln, err := unixConfig.Listen(ctx, "unix", hostAddress)
		if err != nil {
			logListenError(err, "unix", hostAddress)
			return nil, err
		}
		return ln, nil
	}

	// Parse errors are deliberately ignored: a malformed address yields a nil
	// IP, which never equals IPv4loopback1, so it takes the plain Listen path
	// below, and that call reports the real error.
	host, port, _ := net.SplitHostPort(hostAddress)
	portNum, _ := strconv.Atoi(port)
	// NOTE(review): presumably a workaround for hosts where binding a
	// privileged port on the wildcard address is allowed but binding it on
	// 127.0.0.1 is not; confirm.
	usePseudoLoopback := net.ParseIP(host).Equal(IPv4loopback1) && portNum < 1024

	bindAddress := hostAddress
	if usePseudoLoopback {
		bindAddress = net.JoinHostPort("0.0.0.0", port)
	}

	ln, err := listenConfig.Listen(ctx, "tcp", bindAddress)
	if err != nil {
		logListenError(err, "tcp", bindAddress)
		return nil, err
	}
	if !usePseudoLoopback {
		return ln, nil
	}
	return &pseudoLoopbackListener{ln}, nil
}

// ListenPacket opens a UDP socket for the "host:port" address hostAddress.
//
// When the host is IPv4loopback1 and the port is below 1024, the socket is
// bound to 0.0.0.0 on the same port instead and wrapped in a
// pseudoLoopbackPacketConn, which restricts traffic to loopback peers. Every
// other address is listened on exactly as given.
//
// Listen failures are reported through logListenError before being returned.
func ListenPacket(ctx context.Context, listenConfig net.ListenConfig, hostAddress string) (net.PacketConn, error) {
	// Parse errors are deliberately ignored: a malformed address yields a nil
	// IP, which never equals IPv4loopback1, so it is handed to ListenPacket
	// unchanged, and that call reports the real error.
	host, port, _ := net.SplitHostPort(hostAddress)
	portNum, _ := strconv.Atoi(port)
	usePseudoLoopback := net.ParseIP(host).Equal(IPv4loopback1) && portNum < 1024

	bindAddress := hostAddress
	if usePseudoLoopback {
		bindAddress = net.JoinHostPort("0.0.0.0", port)
	}

	conn, err := listenConfig.ListenPacket(ctx, "udp", bindAddress)
	if err != nil {
		logListenError(err, "udp", bindAddress)
		return nil, err
	}
	if !usePseudoLoopback {
		return conn, nil
	}
	return &pseudoLoopbackPacketConn{conn}, nil
}

// pseudoLoopbackListener wraps a net.Listener (bound to a wildcard address by
// Listen) so that Accept only returns connections from loopback peers. All
// other methods are promoted unchanged from the embedded listener.
type pseudoLoopbackListener struct {
	net.Listener
}

// Accept waits for and returns the next connection whose remote address is a
// loopback address.
//
// Connections from any other peer, or whose remote address cannot be split
// into host and port, are closed, logged at debug level, and skipped. They are
// deliberately not reported as an Accept error: many accept loops (including
// net/http.Server.Serve) stop serving on the first non-temporary error, so a
// single connection attempt from another host would otherwise shut the
// forwarder down. Errors from the wrapped listener (e.g. after Close) are
// returned unchanged.
func (p pseudoLoopbackListener) Accept() (net.Conn, error) {
	for {
		conn, err := p.Listener.Accept()
		if err != nil {
			return nil, err
		}

		remoteAddr := conn.RemoteAddr().String() // ip:port
		remoteAddrIP, _, err := net.SplitHostPort(remoteAddr)
		if err != nil {
			logrus.WithError(err).Debugf("pseudoloopback forwarder: rejecting non-loopback remoteAddr %q (unparsable)", remoteAddr)
			_ = conn.Close() // best effort: the connection is being discarded
			continue
		}
		if !IsLoopback(remoteAddrIP) {
			logrus.Debugf("pseudoloopback forwarder: rejecting non-loopback remoteAddr %q", remoteAddr)
			_ = conn.Close() // best effort: the connection is being discarded
			continue
		}

		logrus.Infof("pseudoloopback forwarder: accepting connection from %q", remoteAddr)
		return conn, nil
	}
}

type pseudoLoopbackPacketConn struct {
	net.PacketConn
}

func (pk *pseudoLoopbackPacketConn) ReadFrom(bytes []byte) (n int, addr net.Addr, err error) {
	n, remoteAddr, err := pk.PacketConn.ReadFrom(bytes)
	if err != nil {
		return 0, nil, err
	}

	remoteAddrIP, _, err := net.SplitHostPort(remoteAddr.String())
	if err != nil {
		return 0, nil, err
	}
	if !IsLoopback(remoteAddrIP) {
		return 0, nil, fmt.Errorf("pseudoloopback forwarder: rejecting non-loopback remoteAddr %q", remoteAddr)
	}
	return n, remoteAddr, nil
}

func (pk *pseudoLoopbackPacketConn) WriteTo(bytes []byte, remoteAddr net.Addr) (n int, err error) {
	remoteAddrIP, _, err := net.SplitHostPort(remoteAddr.String())
	if err != nil {
		return 0, err
	}
	if !IsLoopback(remoteAddrIP) {
		return 0, fmt.Errorf("pseudoloopback forwarder: rejecting non-loopback remoteAddr %q", remoteAddr)
	}
	return pk.PacketConn.WriteTo(bytes, remoteAddr)
}

// IsLoopback reports whether addr is the textual form of an IP address that
// net.IP.IsLoopback considers loopback. It returns false for anything
// net.ParseIP cannot parse, such as host names, empty strings, or strings that
// still carry a port.
func IsLoopback(addr string) bool {
	ip := net.ParseIP(addr)
	if ip == nil {
		return false
	}
	return ip.IsLoopback()
}
