package jwxtest

import (
	"bytes"
	"context"
	"crypto/ecdh"
	"crypto/ecdsa"
	"crypto/ed25519"
	"crypto/rand"
	"crypto/rsa"
	"encoding/json"
	"fmt"
	"io"
	"os"
	"strings"
	"testing"

	"github.com/lestrrat-go/jwx/v3/internal/tokens"
	"github.com/lestrrat-go/jwx/v3/jwa"
	"github.com/lestrrat-go/jwx/v3/jwe"
	"github.com/lestrrat-go/jwx/v3/jwk"
	ourecdsa "github.com/lestrrat-go/jwx/v3/jwk/ecdsa"
	"github.com/lestrrat-go/jwx/v3/jws"
	"github.com/stretchr/testify/require"
)

// GenerateRsaKey creates a new 2048-bit RSA private key using crypto/rand as
// the entropy source.
func GenerateRsaKey() (*rsa.PrivateKey, error) {
	const modulusBits = 2048
	priv, err := rsa.GenerateKey(rand.Reader, modulusBits)
	if err != nil {
		return nil, err
	}
	return priv, nil
}

// GenerateRsaJwk generates a fresh RSA private key (see GenerateRsaKey) and
// converts it into a jwk.Key via jwk.Import.
func GenerateRsaJwk() (jwk.Key, error) {
	raw, err := GenerateRsaKey()
	if err != nil {
		return nil, fmt.Errorf(`failed to generate RSA private key: %w`, err)
	}

	imported, importErr := jwk.Import(raw)
	if importErr != nil {
		return nil, fmt.Errorf(`failed to generate jwk.RSAPrivateKey: %w`, importErr)
	}
	return imported, nil
}

// GenerateRsaPublicJwk generates a fresh RSA private JWK and returns the JWK
// for its public half, as produced by jwk.PublicKeyOf.
func GenerateRsaPublicJwk() (jwk.Key, error) {
	privJwk, err := GenerateRsaJwk()
	if err != nil {
		return nil, fmt.Errorf(`failed to generate jwk.RSAPrivateKey: %w`, err)
	}
	return jwk.PublicKeyOf(privJwk)
}

// GenerateEcdsaKey creates a new ECDSA private key on the curve that
// corresponds to alg, using crypto/rand as the entropy source.
//
// An error is returned if alg cannot be mapped to a curve.
func GenerateEcdsaKey(alg jwa.EllipticCurveAlgorithm) (*ecdsa.PrivateKey, error) {
	curve, lookupErr := ourecdsa.CurveFromAlgorithm(alg)
	if lookupErr != nil {
		return nil, fmt.Errorf(`unknown elliptic curve algorithm: %w`, lookupErr)
	}
	return ecdsa.GenerateKey(curve, rand.Reader)
}

// GenerateEcdsaJwk generates a fresh P-521 ECDSA private key and converts it
// into a jwk.Key via jwk.Import.
func GenerateEcdsaJwk() (jwk.Key, error) {
	curveAlg := jwa.P521()
	raw, err := GenerateEcdsaKey(curveAlg)
	if err != nil {
		return nil, fmt.Errorf(`failed to generate ECDSA private key: %w`, err)
	}

	imported, importErr := jwk.Import(raw)
	if importErr != nil {
		return nil, fmt.Errorf(`failed to generate jwk.ECDSAPrivateKey: %w`, importErr)
	}
	return imported, nil
}

// GenerateEcdsaPublicJwk generates a fresh ECDSA private JWK (P-521, see
// GenerateEcdsaJwk) and returns the JWK for its public half.
func GenerateEcdsaPublicJwk() (jwk.Key, error) {
	privJwk, err := GenerateEcdsaJwk()
	if err != nil {
		return nil, fmt.Errorf(`failed to generate jwk.ECDSAPrivateKey: %w`, err)
	}
	return jwk.PublicKeyOf(privJwk)
}

// GenerateSymmetricKey returns 64 bytes read from crypto/rand, intended as the
// raw secret for a symmetric JWK.
//
// It panics if the random source reports an error. Ignoring that error (as
// this helper used to) could hand back a zero-filled or partially filled
// buffer, making "random" keys predictable and hiding bugs in the code under
// test. The signature has no error return, so panicking is the only way to
// surface the failure.
func GenerateSymmetricKey() []byte {
	sharedKey := make([]byte, 64)
	if _, err := rand.Read(sharedKey); err != nil {
		panic(fmt.Errorf(`failed to read random bytes for symmetric key: %w`, err))
	}
	return sharedKey
}

// GenerateSymmetricJwk wraps a freshly generated random secret (see
// GenerateSymmetricKey) in a jwk.Key via jwk.Import.
func GenerateSymmetricJwk() (jwk.Key, error) {
	secret := GenerateSymmetricKey()
	imported, err := jwk.Import(secret)
	if err != nil {
		return nil, fmt.Errorf(`failed to generate jwk.SymmetricKey: %w`, err)
	}
	return imported, nil
}

// GenerateEd25519Key creates a new Ed25519 private key using crypto/rand.
// The public half is discarded here; it can be recovered from the private key
// via its Public method.
func GenerateEd25519Key() (ed25519.PrivateKey, error) {
	var (
		priv ed25519.PrivateKey
		err  error
	)
	_, priv, err = ed25519.GenerateKey(rand.Reader)
	return priv, err
}

// GenerateEd25519Jwk generates a fresh Ed25519 private key and converts it
// into a jwk.Key via jwk.Import.
func GenerateEd25519Jwk() (jwk.Key, error) {
	raw, err := GenerateEd25519Key()
	if err != nil {
		return nil, fmt.Errorf(`failed to generate Ed25519 private key: %w`, err)
	}

	imported, importErr := jwk.Import(raw)
	if importErr != nil {
		return nil, fmt.Errorf(`failed to generate jwk.OKPPrivateKey: %w`, importErr)
	}
	return imported, nil
}

// GenerateX25519Key creates a new X25519 (ECDH) private key using crypto/rand.
func GenerateX25519Key() (*ecdh.PrivateKey, error) {
	curve := ecdh.X25519()
	return curve.GenerateKey(rand.Reader)
}

// GenerateX25519Jwk generates a fresh X25519 private key and converts it into
// a jwk.Key via jwk.Import.
func GenerateX25519Jwk() (jwk.Key, error) {
	raw, err := GenerateX25519Key()
	if err != nil {
		return nil, fmt.Errorf(`failed to generate X25519 private key: %w`, err)
	}

	imported, importErr := jwk.Import(raw)
	if importErr != nil {
		return nil, fmt.Errorf(`failed to generate jwk.OKPPrivateKey: %w`, importErr)
	}
	return imported, nil
}

// WriteFile creates a temporary file in dir (named per template, as for
// os.CreateTemp), copies everything from src into it and syncs it to disk.
//
// On success it returns the file's name and a cleanup function that closes
// and removes the file; the file is left open until cleanup is called. On
// failure any partially written file is cleaned up before returning, and the
// returned cleanup function is nil.
func WriteFile(dir, template string, src io.Reader) (string, func(), error) {
	f, cleanup, err := CreateTempFile(dir, template)
	if err != nil {
		return "", nil, fmt.Errorf(`failed to create temporary file: %w`, err)
	}

	if _, err = io.Copy(f, src); err != nil {
		cleanup()
		return "", nil, fmt.Errorf(`failed to copy content to temporary file: %w`, err)
	}

	if err = f.Sync(); err != nil {
		cleanup()
		return "", nil, fmt.Errorf(`failed to sync file: %w`, err)
	}

	return f.Name(), cleanup, nil
}

// WriteJSONFile JSON-encodes v (with json.Encoder, so the output ends in a
// newline) and writes the result to a temporary file via WriteFile.
// It returns the file name and a cleanup function, exactly as WriteFile does.
func WriteJSONFile(dir, template string, v any) (string, func(), error) {
	encoded := &bytes.Buffer{}
	if err := json.NewEncoder(encoded).Encode(v); err != nil {
		return "", nil, fmt.Errorf(`failed to encode object to JSON: %w`, err)
	}
	return WriteFile(dir, template, encoded)
}

// DumpFile logs the contents of file through t, for debugging.
//
// If the contents contain an opening curly or square bracket they are treated
// as JSON: they are parsed (failing the test if that is not possible) and
// logged pretty-printed. Otherwise the raw contents are logged. For files whose
// name ends in ".jwe", the message is then also parsed with jwe.Parse and
// logged as indented JSON.
func DumpFile(t *testing.T, file string) {
	t.Helper()

	buf, err := os.ReadFile(file)
	require.NoError(t, err, `failed to read file %s for debugging`, file)

	if bytes.ContainsRune(buf, tokens.OpenCurlyBracket) || bytes.ContainsRune(buf, tokens.OpenSquareBracket) {
		// Looks like a JSON-like thing. Dump that in a formatted manner, and
		// be done with it. Unmarshalling into a bare `any` picks a map or a
		// slice from the document itself, so no pre-seeding is required.
		var v any
		require.NoError(t, json.Unmarshal(buf, &v), `failed to parse contents as JSON`)

		formatted, err := json.MarshalIndent(v, "", "  ")
		require.NoError(t, err, `failed to format contents as JSON`)
		t.Logf("=== BEGIN %s (formatted JSON) ===", file)
		t.Logf("%s", formatted)
		t.Logf("=== END   %s (formatted JSON) ===", file)
		return
	}

	t.Logf("=== BEGIN %s (raw) ===", file)
	t.Logf("%s", buf)
	t.Logf("=== END   %s (raw) ===", file)

	// Only .jwe files have a structured form this helper knows how to render.
	// For anything else the raw dump above is all that is logged.
	if !strings.HasSuffix(file, ".jwe") {
		return
	}

	// cross our fingers our jwe implementation works
	m, err := jwe.Parse(buf)
	require.NoError(t, err, `failed to parse JWE encrypted message`)

	formatted, err := json.MarshalIndent(m, "", "  ")
	require.NoError(t, err, `failed to format JWE message as JSON`)
	t.Logf("=== BEGIN %s (formatted JSON) ===", file)
	t.Logf("%s", formatted)
	t.Logf("=== END   %s (formatted JSON) ===", file)
}

// CreateTempFile creates a new temporary file in dir using os.CreateTemp with
// the given name template.
//
// It returns the open file and a cleanup function that closes and removes it.
// Errors from closing or removing are deliberately ignored, so cleanup is safe
// to defer and to call more than once.
func CreateTempFile(dir, template string) (*os.File, func(), error) {
	f, err := os.CreateTemp(dir, template)
	if err != nil {
		return nil, nil, fmt.Errorf(`failed to create temporary file: %w`, err)
	}

	name := f.Name()
	return f, func() {
		_ = f.Close()
		_ = os.Remove(name)
	}, nil
}

// ReadFile returns the entire contents of file.
//
// This is a generic helper: it is used for JWK key files as well as for
// encrypted and signed message files. Its error messages therefore refer to a
// plain "file" rather than a key file. Open and read failures are reported
// separately, and both wrap the underlying error with %w.
func ReadFile(file string) ([]byte, error) {
	f, err := os.Open(file)
	if err != nil {
		return nil, fmt.Errorf(`failed to open file %s: %w`, file, err)
	}
	defer f.Close()

	buf, err := io.ReadAll(f)
	if err != nil {
		return nil, fmt.Errorf(`failed to read from file %s: %w`, file, err)
	}

	return buf, nil
}

// ParseJwkFile reads file and parses its contents as a single JWK using
// jwk.ParseKey. The context argument is currently unused.
func ParseJwkFile(_ context.Context, file string) (jwk.Key, error) {
	buf, err := ReadFile(file)
	if err != nil {
		return nil, fmt.Errorf(`failed to read from key file %s: %w`, file, err)
	}

	key, err := jwk.ParseKey(buf)
	if err != nil {
		return nil, fmt.Errorf(`failed to parse JWK in key file %s: %w`, file, err)
	}

	return key, nil
}

// DecryptJweFile decrypts the JWE message stored in file and returns the
// plaintext payload.
//
// The decryption key is read from the JWK in jwkfile and exported to its raw
// form with jwk.Export. It is then passed to jwe.Decrypt together with the
// key-encryption algorithm alg.
func DecryptJweFile(ctx context.Context, file string, alg jwa.KeyEncryptionAlgorithm, jwkfile string) ([]byte, error) {
	key, err := ParseJwkFile(ctx, jwkfile)
	if err != nil {
		// Report the key file that failed, not the message file.
		return nil, fmt.Errorf(`failed to parse keyfile %s: %w`, jwkfile, err)
	}

	buf, err := ReadFile(file)
	if err != nil {
		return nil, fmt.Errorf(`failed to read from encrypted file %s: %w`, file, err)
	}

	var rawkey any
	if err := jwk.Export(key, &rawkey); err != nil {
		return nil, fmt.Errorf(`failed to obtain raw key from JWK: %w`, err)
	}

	return jwe.Decrypt(buf, jwe.WithKey(alg, rawkey))
}

// EncryptJweFile encrypts payload as a JWE and writes the serialized message
// to a temporary file in dir (name pattern "jwx-test-*.jwe").
//
// The key comes from the JWK in keyfile and is converted according to keyalg:
//   - RSA1_5 / RSA-OAEP* expect an RSA private JWK; its public half is used.
//   - ECDH-ES and ECDH-ES+A*KW expect an ECDSA private JWK; its public half
//     is used.
//   - Any other algorithm exports the JWK as raw bytes (a symmetric secret).
//
// contentalg and compressalg are forwarded to jwe.Encrypt. It returns the file
// name and a cleanup function, as WriteFile does.
func EncryptJweFile(ctx context.Context, dir string, payload []byte, keyalg jwa.KeyEncryptionAlgorithm, keyfile string, contentalg jwa.ContentEncryptionAlgorithm, compressalg jwa.CompressionAlgorithm) (string, func(), error) {
	jwkKey, err := ParseJwkFile(ctx, keyfile)
	if err != nil {
		return "", nil, fmt.Errorf(`failed to parse keyfile %s: %w`, keyfile, err)
	}

	var encKey any
	switch keyalg {
	case jwa.RSA1_5(), jwa.RSA_OAEP(), jwa.RSA_OAEP_256(), jwa.RSA_OAEP_384(), jwa.RSA_OAEP_512():
		var rsaPriv rsa.PrivateKey
		if exportErr := jwk.Export(jwkKey, &rsaPriv); exportErr != nil {
			return "", nil, fmt.Errorf(`failed to obtain raw key: %w`, exportErr)
		}
		encKey = rsaPriv.PublicKey
	case jwa.ECDH_ES(), jwa.ECDH_ES_A128KW(), jwa.ECDH_ES_A192KW(), jwa.ECDH_ES_A256KW():
		var ecPriv ecdsa.PrivateKey
		if exportErr := jwk.Export(jwkKey, &ecPriv); exportErr != nil {
			return "", nil, fmt.Errorf(`failed to obtain raw key: %w`, exportErr)
		}
		encKey = ecPriv.PublicKey
	default:
		var secret []byte
		if exportErr := jwk.Export(jwkKey, &secret); exportErr != nil {
			return "", nil, fmt.Errorf(`failed to obtain raw key: %w`, exportErr)
		}
		encKey = secret
	}

	serialized, err := jwe.Encrypt(
		payload,
		jwe.WithKey(keyalg, encKey),
		jwe.WithContentEncryption(contentalg),
		jwe.WithCompress(compressalg),
	)
	if err != nil {
		return "", nil, fmt.Errorf(`failed to encrypt payload: %w`, err)
	}

	return WriteFile(dir, "jwx-test-*.jwe", bytes.NewReader(serialized))
}

// VerifyJwsFile verifies the JWS message stored in file with algorithm alg and
// returns the verified payload.
//
// The key is read from the JWK in jwkfile and exported with jwk.Export. If it
// is an ECDSA, RSA or Ed25519 private key, its public half is used for
// verification. Any other raw key (e.g. a symmetric secret) is used as-is.
func VerifyJwsFile(ctx context.Context, file string, alg jwa.SignatureAlgorithm, jwkfile string) ([]byte, error) {
	key, err := ParseJwkFile(ctx, jwkfile)
	if err != nil {
		// Report the key file that failed, not the message file.
		return nil, fmt.Errorf(`failed to parse keyfile %s: %w`, jwkfile, err)
	}

	buf, err := ReadFile(file)
	if err != nil {
		return nil, fmt.Errorf(`failed to read from signed file %s: %w`, file, err)
	}

	var rawkey any
	if err := jwk.Export(key, &rawkey); err != nil {
		return nil, fmt.Errorf(`failed to obtain raw key from JWK: %w`, err)
	}

	pubkey := rawkey
	switch tkey := rawkey.(type) {
	case *ecdsa.PrivateKey:
		pubkey = tkey.PublicKey
	case *rsa.PrivateKey:
		pubkey = tkey.PublicKey
	case ed25519.PrivateKey:
		// ed25519.PrivateKey is a slice type and is normally handled by value
		// (that is how crypto/ed25519 returns it), so this is the case that
		// typically matches for Ed25519 keys.
		pubkey = tkey.Public()
	case *ed25519.PrivateKey:
		pubkey = tkey.Public()
	}

	return jws.Verify(buf, jws.WithKey(alg, pubkey))
}

// SignJwsFile signs payload with the JWK read from keyfile using algorithm
// alg. It writes the resulting JWS to a temporary file in dir (name pattern
// "jwx-test-*.jws") and returns the file name and a cleanup function, as
// WriteFile does.
func SignJwsFile(ctx context.Context, dir string, payload []byte, alg jwa.SignatureAlgorithm, keyfile string) (string, func(), error) {
	signingKey, err := ParseJwkFile(ctx, keyfile)
	if err != nil {
		return "", nil, fmt.Errorf(`failed to parse keyfile %s: %w`, keyfile, err)
	}

	signed, signErr := jws.Sign(payload, jws.WithKey(alg, signingKey))
	if signErr != nil {
		return "", nil, fmt.Errorf(`failed to sign payload: %w`, signErr)
	}

	out := bytes.NewReader(signed)
	return WriteFile(dir, "jwx-test-*.jws", out)
}
