package dotnet

import (
	"encoding/base64"
	"fmt"
	"strings"

	"github.com/vulncheck-oss/go-exploit/output"
	"github.com/vulncheck-oss/go-exploit/transform"
)

// Record is the behaviour shared by the MS-NRBF record structs in this file.
//
// GetRecordType returns the record's RecordTypeEnum value (looked up in RecordTypeEnumMap),
// ToRecordBin serializes the record into its binary-stream bytes (false on failure), and
// ToXML renders it as a SOAP member element (false when unsupported or on failure).
type Record interface {
	GetRecordType() int
	ToRecordBin() (string, bool)
	// TOXML impls, exist to convert a given record into the expected SOAP XML element for the SOAP formatter. Not all records have been implemented.
	// As used by the implementations in this file the arguments are: a ClassInfo whose
	// MemberNames[index] names the element, its MemberTypeInfo, a BinaryLibraryRecord,
	// the member index, and an XML namespace prefix.
	ToXML(ClassInfo, MemberTypeInfo, BinaryLibraryRecord, int, string) (MemberNode, bool)
}

// MemberPrimitiveTypedRecord holds one primitive value together with its primitive type.
// ToRecordBin writes PrimitiveTypeEnum as one byte followed by Value.PrimToString().
type MemberPrimitiveTypedRecord struct {
	PrimitiveTypeEnum int
	Value             Primitive
}

// BinaryArrayRecord describes an MS-NRBF BinaryArray record header. ToRecordBin packs
// ObjectID, Rank, every Lengths entry and (only when BinaryArrayTypeEnum > 2) every
// LowerBounds entry with transform.PackLittleInt32, writes BinaryArrayTypeEnum and
// TypeEnum as single bytes, and accepts AdditionalTypeInfo entries of type int, string,
// or ClassTypeInfo.
type BinaryArrayRecord struct {
	ObjectID            int
	BinaryArrayTypeEnum int // 1byte
	Rank                int
	Lengths             []int
	LowerBounds         []int
	TypeEnum            int // 1byte
	AdditionalTypeInfo  []any
}

// ClassWithIDRecord is an MS-NRBF ClassWithId record. ToRecordBin packs ObjectID and
// MetadataID with transform.PackLittleInt32, followed by the encoded MemberValues.
// NOTE(review): per MS-NRBF, MetadataID is the ObjectID of an earlier class record whose
// metadata this record reuses; confirm against callers.
type ClassWithIDRecord struct {
	ObjectID     int
	MetadataID   int
	MemberValues []any
}

// BinaryLibraryRecord associates a library name with an ID. ToRecordBin packs ID with
// transform.PackLittleInt32 and writes Library as a length-prefixed string.
type BinaryLibraryRecord struct {
	ID      int
	Library string
}

// SystemClassWithMembersAndTypesRecord is serialized like ClassWithMembersAndTypesRecord
// but without a library ID: ToRecordBin writes ClassInfo, MemberTypeInfo and then the
// encoded MemberValues.
type SystemClassWithMembersAndTypesRecord struct {
	ClassInfo      ClassInfo
	MemberTypeInfo MemberTypeInfo
	MemberValues   []any
}

// ClassWithMembersAndTypesRecord is an MS-NRBF ClassWithMembersAndTypes record.
// ToRecordBin writes ClassInfo, MemberTypeInfo, LibraryID and then the encoded
// MemberValues. The BinaryLibrary field is not written by ToRecordBin.
type ClassWithMembersAndTypesRecord struct {
	ClassInfo      ClassInfo
	MemberTypeInfo MemberTypeInfo
	LibraryID      int
	MemberValues   []any
	BinaryLibrary  BinaryLibraryRecord // Not _really_ supposed to be here per MSDN but I placed it here for convenience
}

// SerializationHeaderRecord is the stream header record. ToRecordBin writes RootID before
// HeaderID, followed by a fixed major/minor version of 1/0.
type SerializationHeaderRecord struct {
	HeaderID int
	RootID   int
}

// MemberReferenceRecord refers to another object by ID. ToXML renders it as an element
// whose href is "#ref-<IDRef>".
type MemberReferenceRecord struct {
	IDRef int
}

// ObjectNullMultiple256Record presumably stands for NullCount consecutive null values
// (MS-NRBF ObjectNullMultiple256). ToRecordBin rejects NullCount values outside 0-255.
type ObjectNullMultiple256Record struct {
	NullCount int
}

// ObjectNullRecord is a single null value. It has no fields and serializes to just its
// record-type byte.
type ObjectNullRecord struct{}

// BinaryObjectString is a string object with an ID. ToRecordBin writes ObjectID followed
// by Value as a length-prefixed string.
type BinaryObjectString struct {
	ObjectID int
	Value    string
}

// ArrayInfo carries the ObjectID and MemberCount used by the ArraySingle* records. Both
// are packed with transform.PackLittleInt32.
type ArrayInfo struct {
	ObjectID    int
	MemberCount int
}

// ArraySinglePrimitiveRecord is an array of primitive values whose element bytes are
// supplied already encoded in Members.
type ArraySinglePrimitiveRecord struct {
	PrimitiveTypeEnum int
	ArrayInfo         ArrayInfo
	Members           string // raw, already-encoded element bytes (e.g. "\x00\x01"); written verbatim by ToRecordBin and base64-encoded by ToXMLBespoke
}

// ArraySingleStringRecord is an array of string entries. Members may be Record values
// (serialized via their ToRecordBin) or pre-encoded strings (written verbatim). Members
// of other types are skipped by ToRecordBin.
type ArraySingleStringRecord struct {
	ArrayInfo ArrayInfo
	Members   []any
}

// ArraySingleObjectRecord is an array of object entries. ToRecordBin serializes only
// Record members and skips anything else.
type ArraySingleObjectRecord struct {
	ArrayInfo ArrayInfo
	Members   []any
}

// The GetRecordType implementations below return each record's RecordTypeEnum value,
// looked up in RecordTypeEnumMap under the record's MS-NRBF name. None of them read
// receiver state, so their receivers are left unnamed.

// GetRecordType returns the RecordTypeEnum value for ObjectNullMultiple256.
func (ObjectNullMultiple256Record) GetRecordType() int { return RecordTypeEnumMap["ObjectNullMultiple256"] }

// GetRecordType returns the RecordTypeEnum value for ArraySinglePrimitive.
func (ArraySinglePrimitiveRecord) GetRecordType() int { return RecordTypeEnumMap["ArraySinglePrimitive"] }

// GetRecordType returns the RecordTypeEnum value for BinaryArray.
func (BinaryArrayRecord) GetRecordType() int { return RecordTypeEnumMap["BinaryArray"] }

// GetRecordType returns the RecordTypeEnum value for ArraySingleObject.
func (ArraySingleObjectRecord) GetRecordType() int { return RecordTypeEnumMap["ArraySingleObject"] }

// GetRecordType returns the RecordTypeEnum value for ArraySingleString.
func (ArraySingleStringRecord) GetRecordType() int { return RecordTypeEnumMap["ArraySingleString"] }

// GetRecordType returns the RecordTypeEnum value for ClassWithId.
func (ClassWithIDRecord) GetRecordType() int { return RecordTypeEnumMap["ClassWithId"] }

// GetRecordType returns the RecordTypeEnum value for BinaryObjectString.
func (BinaryObjectString) GetRecordType() int { return RecordTypeEnumMap["BinaryObjectString"] }

// GetRecordType returns the RecordTypeEnum value for ClassWithMembersAndTypes.
func (ClassWithMembersAndTypesRecord) GetRecordType() int {
	return RecordTypeEnumMap["ClassWithMembersAndTypes"]
}

// GetRecordType returns the RecordTypeEnum value for SystemClassWithMembersAndTypes.
func (SystemClassWithMembersAndTypesRecord) GetRecordType() int {
	return RecordTypeEnumMap["SystemClassWithMembersAndTypes"]
}

// GetRecordType returns the RecordTypeEnum value for the stream header (SerializedStreamHeader).
func (SerializationHeaderRecord) GetRecordType() int { return RecordTypeEnumMap["SerializedStreamHeader"] }

// GetRecordType returns the RecordTypeEnum value for BinaryLibrary.
func (BinaryLibraryRecord) GetRecordType() int { return RecordTypeEnumMap["BinaryLibrary"] }

// GetRecordType returns the RecordTypeEnum value for MemberReference.
func (MemberReferenceRecord) GetRecordType() int { return RecordTypeEnumMap["MemberReference"] }

// GetRecordType returns the RecordTypeEnum value for MemberPrimitiveTyped.
func (MemberPrimitiveTypedRecord) GetRecordType() int { return RecordTypeEnumMap["MemberPrimitiveTyped"] }

// GetRecordType returns the RecordTypeEnum value for ObjectNull.
func (ObjectNullRecord) GetRecordType() int { return RecordTypeEnumMap["ObjectNull"] }

// ToXMLBespoke renders an ArraySinglePrimitiveRecord as a standalone SOAP-ENC:Array
// element. Unlike the other records' ToXML methods, it is not processed within the context
// of a parent's member values and takes no member context. Its result is usually appended
// outside the member values.
//
// The element gets id "ref-<ObjectID>" and xsi:type "SOAP-ENC:base64". Its content is the
// standard base64 encoding of Members. It always returns true.
func (arraySinglePrimitiveRecord ArraySinglePrimitiveRecord) ToXMLBespoke() (ClassDataNode, bool) {
	var node ClassDataNode
	node.XMLName.Local = "SOAP-ENC:Array"
	node.ID = fmt.Sprintf("ref-%d", arraySinglePrimitiveRecord.ArrayInfo.ObjectID)
	node.addAttribute("xsi:type", "SOAP-ENC:base64")

	// The content is base64 so that it matches the xsi:type declared above.
	node.Content = base64.StdEncoding.EncodeToString([]byte(arraySinglePrimitiveRecord.Members))

	return node, true
}

// ToXML is deliberately unusable for ArraySinglePrimitiveRecord. The record is rendered
// outside the member-value context, so callers must use ToXMLBespoke (which takes
// different parameters) instead. This method reports a framework error naming the correct
// type and returns an empty MemberNode with false.
func (arraySinglePrimitiveRecord ArraySinglePrimitiveRecord) ToXML(_ ClassInfo, _ MemberTypeInfo, _ BinaryLibraryRecord, _ int, _ string) (MemberNode, bool) {
	output.PrintFrameworkError("ToXML for ArraySinglePrimitiveRecord cannot be used, call <instance>.ToXMLBespoke() instead. Note: uses different parameters.")

	return MemberNode{}, false
}

// The ToXML methods below are placeholders for records whose SOAP conversion has not been
// written yet. Each reports a framework error naming its record and returns an empty
// MemberNode with false. None of them use their receiver or arguments.

// ToXML is not yet implemented for ObjectNullMultiple256Record.
func (ObjectNullMultiple256Record) ToXML(_ ClassInfo, _ MemberTypeInfo, _ BinaryLibraryRecord, _ int, _ string) (MemberNode, bool) {
	var unimplemented MemberNode
	output.PrintFrameworkError("ToXML for ObjectNullMultiple256Record not yet implemented")
	return unimplemented, false
}

// ToXML is not yet implemented for MemberPrimitiveTypedRecord.
func (MemberPrimitiveTypedRecord) ToXML(_ ClassInfo, _ MemberTypeInfo, _ BinaryLibraryRecord, _ int, _ string) (MemberNode, bool) {
	var unimplemented MemberNode
	output.PrintFrameworkError("ToXML for MemberPrimitiveTypedRecord not yet implemented")
	return unimplemented, false
}

// ToXML is not yet implemented for BinaryArrayRecord.
func (BinaryArrayRecord) ToXML(_ ClassInfo, _ MemberTypeInfo, _ BinaryLibraryRecord, _ int, _ string) (MemberNode, bool) {
	var unimplemented MemberNode
	output.PrintFrameworkError("ToXML for BinaryArrayRecord not yet implemented")
	return unimplemented, false
}

// ToXML is not yet implemented for ArraySingleObjectRecord.
func (ArraySingleObjectRecord) ToXML(_ ClassInfo, _ MemberTypeInfo, _ BinaryLibraryRecord, _ int, _ string) (MemberNode, bool) {
	var unimplemented MemberNode
	output.PrintFrameworkError("ToXML for ArraySingleObjectRecord not yet implemented")
	return unimplemented, false
}

// ToXML is not yet implemented for ArraySingleStringRecord.
func (ArraySingleStringRecord) ToXML(_ ClassInfo, _ MemberTypeInfo, _ BinaryLibraryRecord, _ int, _ string) (MemberNode, bool) {
	var unimplemented MemberNode
	output.PrintFrameworkError("ToXML for ArraySingleStringRecord not yet implemented")
	return unimplemented, false
}

// ToXML is not yet implemented for ClassWithIDRecord.
func (ClassWithIDRecord) ToXML(_ ClassInfo, _ MemberTypeInfo, _ BinaryLibraryRecord, _ int, _ string) (MemberNode, bool) {
	var unimplemented MemberNode
	output.PrintFrameworkError("ToXML for ClassWithIDRecord not yet implemented")
	return unimplemented, false
}

// ToXML renders the string as a SOAP member element and reports whether it succeeded.
//
// The element is named after classInfo.MemberNames[currentIndex]. Its id is
// "ref-<ObjectID>", which is the target of MemberReferenceRecord's "#ref-<id>" hrefs. Its
// xsi:type is "xsd:" plus the lower-cased memberTypeInfo.BinaryTypes[currentIndex]. The
// content is Value passed through escapeTags.
//
// If currentIndex is outside MemberNames or BinaryTypes, it reports a framework error and
// returns false instead of panicking.
func (binaryObjectString BinaryObjectString) ToXML(classInfo ClassInfo, memberTypeInfo MemberTypeInfo, _ BinaryLibraryRecord, currentIndex int, _ string) (MemberNode, bool) {
	if currentIndex < 0 || currentIndex >= len(classInfo.MemberNames) || currentIndex >= len(memberTypeInfo.BinaryTypes) {
		output.PrintfFrameworkError("BinaryObjectString.ToXML: member index %d out of range (%d member names, %d binary types)",
			currentIndex, len(classInfo.MemberNames), len(memberTypeInfo.BinaryTypes))

		return MemberNode{}, false
	}

	memberNode := MemberNode{}
	memberNode.XMLName.Local = classInfo.MemberNames[currentIndex]
	memberNode.ID = fmt.Sprintf("ref-%d", binaryObjectString.ObjectID)
	memberNode.XsiType = "xsd:" + strings.ToLower(memberTypeInfo.BinaryTypes[currentIndex])
	memberNode.Content = escapeTags(binaryObjectString.Value)

	return memberNode, true
}

// ToXML renders a class-typed member as a SOAP element named after
// classInfo.MemberNames[currentIndex], where classInfo is the class that owns the member.
//
// Both the xsi:type and its namespace come from this record's own ClassInfo, because the
// type being named is the member's class, not the owner's:
//   - xsi:type is "<ns>:<GetBaseClassName()>".
//   - An xmlns:<ns> attribute declares
//     http://schemas.microsoft.com/clr/nsassem/<GetLeadingClassName()>/<library>, with
//     <library> taken from binaryLibraryRecord.Library.
//
// Using ns for the prefix keeps the xsi:type consistent with the namespace declared on
// this element.
//
// If currentIndex is outside classInfo.MemberNames, it reports a framework error and
// returns false.
func (classWithMembersAndTypesRecord ClassWithMembersAndTypesRecord) ToXML(classInfo ClassInfo, _ MemberTypeInfo, binaryLibraryRecord BinaryLibraryRecord, currentIndex int, ns string) (MemberNode, bool) {
	if currentIndex < 0 || currentIndex >= len(classInfo.MemberNames) {
		output.PrintfFrameworkError("ClassWithMembersAndTypesRecord.ToXML: member index %d out of range (%d member names)",
			currentIndex, len(classInfo.MemberNames))

		return MemberNode{}, false
	}

	memberClass := classWithMembersAndTypesRecord.ClassInfo

	memberNode := MemberNode{}
	memberNode.XMLName.Local = classInfo.MemberNames[currentIndex]
	memberNode.XsiType = ns + ":" + memberClass.GetBaseClassName()
	libURL := fmt.Sprintf("http://schemas.microsoft.com/clr/nsassem/%s/%s", memberClass.GetLeadingClassName(), binaryLibraryRecord.Library)
	memberNode.addAttribute("xmlns:"+ns, libURL)
	memberNode.Content = "Binary" // NOT 100% sure this is always the case but it was for DataSet. Once we find out if/when/why this is the case we can implement that logic

	return memberNode, true
}

// The ToXML methods below are placeholders for records whose SOAP conversion has not been
// written yet. Each reports a framework error naming its record and returns an empty
// MemberNode with false. None of them use their receiver or arguments.

// ToXML is not yet implemented for SystemClassWithMembersAndTypesRecord.
func (SystemClassWithMembersAndTypesRecord) ToXML(_ ClassInfo, _ MemberTypeInfo, _ BinaryLibraryRecord, _ int, _ string) (MemberNode, bool) {
	var unimplemented MemberNode
	output.PrintFrameworkError("ToXML for SystemClassWithMembersAndTypesRecord not yet implemented")
	return unimplemented, false
}

// ToXML is not yet implemented for SerializationHeaderRecord.
func (SerializationHeaderRecord) ToXML(_ ClassInfo, _ MemberTypeInfo, _ BinaryLibraryRecord, _ int, _ string) (MemberNode, bool) {
	var unimplemented MemberNode
	output.PrintFrameworkError("ToXML for SerializationHeaderRecord not yet implemented")
	return unimplemented, false
}

// ToXML is not yet implemented for BinaryLibraryRecord.
func (BinaryLibraryRecord) ToXML(_ ClassInfo, _ MemberTypeInfo, _ BinaryLibraryRecord, _ int, _ string) (MemberNode, bool) {
	var unimplemented MemberNode
	output.PrintFrameworkError("ToXML for BinaryLibraryRecord not yet implemented")
	return unimplemented, false
}

// ToXML renders a member reference as an element named after
// classInfo.MemberNames[currentIndex]. Its href is "#ref-<IDRef>", which points at the
// element carrying id "ref-<IDRef>".
//
// If currentIndex is outside classInfo.MemberNames, it reports a framework error and
// returns false instead of panicking.
func (memberReferenceRecord MemberReferenceRecord) ToXML(classInfo ClassInfo, _ MemberTypeInfo, _ BinaryLibraryRecord, currentIndex int, _ string) (MemberNode, bool) {
	if currentIndex < 0 || currentIndex >= len(classInfo.MemberNames) {
		output.PrintfFrameworkError("MemberReferenceRecord.ToXML: member index %d out of range (%d member names)",
			currentIndex, len(classInfo.MemberNames))

		return MemberNode{}, false
	}

	memberNode := MemberNode{}
	memberNode.XMLName.Local = classInfo.MemberNames[currentIndex]
	memberNode.HREF = fmt.Sprintf("#ref-%d", memberReferenceRecord.IDRef)

	return memberNode, true
}

// ToXML renders a null member as an element named after classInfo.MemberNames[currentIndex],
// with xsi:type "xsi:anyType" and xsi:null "1".
//
// If currentIndex is outside classInfo.MemberNames, it reports a framework error and
// returns false instead of panicking.
func (objectNullRecord ObjectNullRecord) ToXML(classInfo ClassInfo, _ MemberTypeInfo, _ BinaryLibraryRecord, currentIndex int, _ string) (MemberNode, bool) {
	if currentIndex < 0 || currentIndex >= len(classInfo.MemberNames) {
		output.PrintfFrameworkError("ObjectNullRecord.ToXML: member index %d out of range (%d member names)",
			currentIndex, len(classInfo.MemberNames))

		return MemberNode{}, false
	}

	memberNode := MemberNode{}
	memberNode.XMLName.Local = classInfo.MemberNames[currentIndex]
	memberNode.XsiType = "xsi:anyType"
	memberNode.XsiNull = "1"

	return memberNode, true
}

// The ToRecordBin implementations convert each record struct into the byte stream expected
// by the MS-NRBF serialized object format. Each returns the encoded bytes as a string and
// reports false on failure.
//
// For ArraySingleStringRecord the output is:
//   - the record-type byte;
//   - ArrayInfo.ObjectID and ArrayInfo.MemberCount, each packed with transform.PackLittleInt32;
//   - then each member in order. Record members are encoded via their own ToRecordBin, and
//     a failure there aborts with ("", false). String members are appended verbatim.
//     Members of any other type are skipped.
func (arraySingleStringRecord ArraySingleStringRecord) ToRecordBin() (string, bool) {
	var encoded strings.Builder
	encoded.WriteString(string(byte(arraySingleStringRecord.GetRecordType())))
	encoded.WriteString(transform.PackLittleInt32(arraySingleStringRecord.ArrayInfo.ObjectID))
	encoded.WriteString(transform.PackLittleInt32(arraySingleStringRecord.ArrayInfo.MemberCount))

	for _, member := range arraySingleStringRecord.Members {
		switch value := member.(type) {
		case Record:
			bin, ok := value.ToRecordBin()
			if !ok {
				return "", false
			}
			encoded.WriteString(bin)
		case string:
			encoded.WriteString(value)
		}
	}

	return encoded.String(), true
}

// ToRecordBin serializes the BinaryArray record header. Fields are written in this order:
//   - the record-type byte;
//   - ObjectID (packed with transform.PackLittleInt32);
//   - BinaryArrayTypeEnum as one raw byte;
//   - Rank, then every Lengths entry (packed with transform.PackLittleInt32);
//   - every LowerBounds entry, only when BinaryArrayTypeEnum > 2;
//   - TypeEnum as one raw byte;
//   - the AdditionalTypeInfo entries.
//
// AdditionalTypeInfo entries may be:
//   - int: written as one raw byte (e.g. a primitive type enum)
//   - string: written as a length-prefixed string
//   - ClassTypeInfo: TypeName as a length-prefixed string, then LibraryID packed as int32
//
// A nil entry or an entry of any other type reports a framework error and returns
// ("", false).
func (binaryArrayRecord BinaryArrayRecord) ToRecordBin() (string, bool) {
	// Single-byte fields are written with WriteByte so the byte lands in the stream as-is.
	// string(byte(x)) would instead UTF-8 encode any value >= 0x80 as two bytes.
	var out strings.Builder
	out.WriteByte(byte(binaryArrayRecord.GetRecordType()))
	out.WriteString(transform.PackLittleInt32(binaryArrayRecord.ObjectID))
	out.WriteByte(byte(binaryArrayRecord.BinaryArrayTypeEnum))
	out.WriteString(transform.PackLittleInt32(binaryArrayRecord.Rank))
	for _, length := range binaryArrayRecord.Lengths {
		out.WriteString(transform.PackLittleInt32(length))
	}

	// Per MS-NRBF's BinaryArrayTypeEnumeration, values 3-5 are the *Offset variants
	// (SingleOffset, JaggedOffset, RectangularOffset), which are the only ones that carry
	// LowerBounds.
	if binaryArrayRecord.BinaryArrayTypeEnum > 2 {
		for _, bound := range binaryArrayRecord.LowerBounds {
			out.WriteString(transform.PackLittleInt32(bound))
		}
	}

	out.WriteByte(byte(binaryArrayRecord.TypeEnum))

	for _, addInfo := range binaryArrayRecord.AdditionalTypeInfo {
		switch info := addInfo.(type) {
		case nil:
			output.PrintFrameworkError("Nil additional info provided")

			return "", false
		case int:
			out.WriteByte(byte(info))
		case string:
			out.WriteString(lengthPrefixedString(info))
		case ClassTypeInfo:
			// handling ClassTypeInfo used for 'Class' type
			out.WriteString(lengthPrefixedString(info.TypeName))
			out.WriteString(transform.PackLittleInt32(info.LibraryID))
		default:
			// %T/%#v give a readable diagnostic; %q garbles non-string values.
			output.PrintfFrameworkError("Unsupported additional info type provided %T: %#v", addInfo, addInfo)

			return "", false
		}
	}

	return out.String(), true
}

// ToRecordBin encodes an ArraySingleObject record:
//   - the record-type byte;
//   - ArrayInfo.ObjectID and ArrayInfo.MemberCount, each packed with transform.PackLittleInt32;
//   - the ToRecordBin output of every member that is a Record.
//
// Non-Record members are skipped. A member that fails to encode aborts with ("", false).
func (arraySingleObjectRecord ArraySingleObjectRecord) ToRecordBin() (string, bool) {
	var body strings.Builder
	body.WriteString(string(byte(arraySingleObjectRecord.GetRecordType())))
	body.WriteString(transform.PackLittleInt32(arraySingleObjectRecord.ArrayInfo.ObjectID))
	body.WriteString(transform.PackLittleInt32(arraySingleObjectRecord.ArrayInfo.MemberCount))

	for _, member := range arraySingleObjectRecord.Members {
		rec, isRecord := member.(Record)
		if !isRecord {
			continue
		}

		bin, ok := rec.ToRecordBin()
		if !ok {
			return "", false
		}
		body.WriteString(bin)
	}

	return body.String(), true
}

// ToRecordBin encodes an ArraySinglePrimitive record:
//   - the record-type byte;
//   - ArrayInfo.ObjectID and ArrayInfo.MemberCount (packed with transform.PackLittleInt32);
//   - the PrimitiveTypeEnum byte;
//   - the pre-encoded Members bytes, copied verbatim.
//
// It never fails.
func (arraySinglePrimitiveRecord ArraySinglePrimitiveRecord) ToRecordBin() (string, bool) {
	info := arraySinglePrimitiveRecord.ArrayInfo
	parts := []string{
		string(byte(arraySinglePrimitiveRecord.GetRecordType())),
		transform.PackLittleInt32(info.ObjectID),
		transform.PackLittleInt32(info.MemberCount),
		string(byte(arraySinglePrimitiveRecord.PrimitiveTypeEnum)),
		arraySinglePrimitiveRecord.Members,
	}

	return strings.Join(parts, ""), true
}

// ToRecordBin encodes the record as its record-type byte followed by NullCount as a single
// raw byte.
//
// NullCount must be in [0, 255]. Any other value reports a framework error and returns
// ("", false). The range is checked before anything is encoded.
func (objectNullMultiple256Record ObjectNullMultiple256Record) ToRecordBin() (string, bool) {
	nullCount := objectNullMultiple256Record.NullCount
	if nullCount < 0 || nullCount > 255 {
		output.PrintFrameworkError("Invalid value for objectNullMultiple256Record.NullCount, MUST be between 0-255 (inclusive)")

		return "", false
	}

	// Build from a []byte so that each value is stored as exactly one byte. With
	// string(byte(n)), Go treats n as a rune, so 128-255 would become two UTF-8 bytes and
	// corrupt the stream.
	return string([]byte{byte(objectNullMultiple256Record.GetRecordType()), byte(nullCount)}), true
}

// ToRecordBin encodes a MemberPrimitiveTyped record: the record-type byte, the
// PrimitiveTypeEnum as one raw byte, then Value.PrimToString(). It never fails.
func (memberPrimitiveTypedRecord MemberPrimitiveTypedRecord) ToRecordBin() (string, bool) {
	var sb strings.Builder
	sb.WriteString(string(byte(memberPrimitiveTypedRecord.GetRecordType())))
	sb.WriteByte(byte(memberPrimitiveTypedRecord.PrimitiveTypeEnum))
	sb.WriteString(memberPrimitiveTypedRecord.Value.PrimToString())

	return sb.String(), true
}

// ToRecordBin encodes a MemberReference record: the record-type byte followed by IDRef
// packed with transform.PackLittleInt32. It never fails.
func (memberReferenceRecord MemberReferenceRecord) ToRecordBin() (string, bool) {
	return string(byte(memberReferenceRecord.GetRecordType())) + transform.PackLittleInt32(memberReferenceRecord.IDRef), true
}

// ToRecordBin encodes an ObjectNull record, which is just its record-type byte. It never
// fails.
func (objectNullRecord ObjectNullRecord) ToRecordBin() (string, bool) {
	recordType := objectNullRecord.GetRecordType()

	return string(byte(recordType)), true
}

// ToRecordBin encodes the stream header:
//   - the record-type byte;
//   - RootID, then HeaderID;
//   - a fixed major version of 1 and minor version of 0.
//
// All four integers are packed with transform.PackLittleInt32. It never fails.
// https://docs.microsoft.com/en-us/openspecs/windows_protocols/ms-nrbf/a7e578d3-400a-4249-9424-7529d10d1b3c
func (serializationHeaderRecord SerializationHeaderRecord) ToRecordBin() (string, bool) {
	const (
		majorVersion = 1 // MUST be 1
		minorVersion = 0 // MUST be 0
	)

	var sb strings.Builder
	sb.WriteString(string(byte(serializationHeaderRecord.GetRecordType())))
	for _, field := range []int{serializationHeaderRecord.RootID, serializationHeaderRecord.HeaderID, majorVersion, minorVersion} {
		sb.WriteString(transform.PackLittleInt32(field))
	}

	return sb.String(), true
}

// ToRecordBin encodes a BinaryLibrary record: the record-type byte, the ID packed with
// transform.PackLittleInt32, then the library name as a length-prefixed string. It never
// fails.
func (binaryLibraryRecord BinaryLibraryRecord) ToRecordBin() (string, bool) {
	encoded := strings.Join([]string{
		string(byte(binaryLibraryRecord.GetRecordType())),
		transform.PackLittleInt32(binaryLibraryRecord.ID),
		lengthPrefixedString(binaryLibraryRecord.Library),
	}, "")

	return encoded, true
}

// ToRecordBin encodes a ClassWithId record: the record-type byte, then ObjectID and
// MetadataID (each packed with transform.PackLittleInt32), then every member value in
// order.
//
// Member values are encoded the same way as in the *ClassWithMembersAndTypes records:
//   - Record:    its own ToRecordBin output. A failure reports an error and returns ("", false).
//   - string:    appended verbatim (pre-encoded bytes).
//   - int:       packed with transform.PackLittleInt32.
//   - Primitive: its PrimToString() output.
//   - bool:      a single 0x01 / 0x00 byte.
//
// Values of any other type are skipped.
func (classWithIDRecord ClassWithIDRecord) ToRecordBin() (string, bool) {
	recordTypeEnumString := string(byte(classWithIDRecord.GetRecordType()))
	objectIDString := transform.PackLittleInt32(classWithIDRecord.ObjectID)
	metadataIDString := transform.PackLittleInt32(classWithIDRecord.MetadataID)

	var members strings.Builder
	for _, memberValue := range classWithIDRecord.MemberValues {
		switch value := memberValue.(type) {
		case Record:
			recordBin, ok := value.ToRecordBin()
			if !ok {
				output.PrintFrameworkError("Failed to convert member value into record")

				return "", false
			}
			members.WriteString(recordBin)
		case string:
			members.WriteString(value)
		case int:
			members.WriteString(transform.PackLittleInt32(value))
		case Primitive:
			// Previously dropped silently, which left the stream short of a member value.
			members.WriteString(value.PrimToString())
		case bool:
			// Previously dropped silently, which left the stream short of a member value.
			if value {
				members.WriteByte(0x01)
			} else {
				members.WriteByte(0x00)
			}
		}
	}

	return recordTypeEnumString + objectIDString + metadataIDString + members.String(), true
}

// ToRecordBin encodes a BinaryObjectString record: the record-type byte, the ObjectID
// packed with transform.PackLittleInt32, then Value as a length-prefixed string. It never
// fails.
// https://learn.microsoft.com/en-us/openspecs/windows_protocols/ms-nrbf/eb503ca5-e1f6-4271-a7ee-c4ca38d07996
func (binaryObjectString BinaryObjectString) ToRecordBin() (string, bool) {
	var sb strings.Builder
	sb.WriteString(string(byte(binaryObjectString.GetRecordType())))
	sb.WriteString(transform.PackLittleInt32(binaryObjectString.ObjectID))
	sb.WriteString(lengthPrefixedString(binaryObjectString.Value))

	return sb.String(), true
}

// ToRecordBin encodes a SystemClassWithMembersAndTypes record: the record-type byte, the
// ClassInfo, the MemberTypeInfo, then the encoded member values. Unlike
// ClassWithMembersAndTypes, there is no library ID.
//
// Member values are encoded as follows:
//   - Record:    its own ToRecordBin output. A failure reports an error and returns ("", false).
//   - Primitive: PrimToString().
//   - int:       packed with transform.PackLittleInt32.
//   - bool:      a single 0x01 / 0x00 byte.
//   - string:    appended verbatim.
//
// Other types are skipped. If MemberTypeInfo.ToBin fails, ("", false) is returned.
func (systemClassWithMembersAndTypesRecord SystemClassWithMembersAndTypesRecord) ToRecordBin() (string, bool) {
	var values strings.Builder
	for _, memberValue := range systemClassWithMembersAndTypesRecord.MemberValues {
		switch v := memberValue.(type) {
		case Record:
			bin, ok := v.ToRecordBin()
			if !ok {
				output.PrintFrameworkError("Failed to convert member value into record")

				return "", false
			}
			values.WriteString(bin)
		case Primitive:
			values.WriteString(v.PrimToString())
		case int:
			values.WriteString(transform.PackLittleInt32(v))
		case bool:
			if v {
				values.WriteByte(0x01)
			} else {
				values.WriteByte(0x00)
			}
		case string:
			values.WriteString(v)
		}
	}

	recordType := string(byte(systemClassWithMembersAndTypesRecord.GetRecordType()))
	typeInfo, ok := systemClassWithMembersAndTypesRecord.MemberTypeInfo.ToBin()
	if !ok {
		return "", false
	}

	// layout: record type, class info (objid, name, count, member names),
	// member type info (int8 types + additional info), member values
	return recordType + systemClassWithMembersAndTypesRecord.ClassInfo.ToBin() + typeInfo + values.String(), true
}

// ToRecordBin encodes a ClassWithMembersAndTypes record: the record-type byte, the
// ClassInfo, the MemberTypeInfo, the LibraryID (packed with transform.PackLittleInt32),
// then the encoded member values. The BinaryLibrary field is not written.
//
// Member values are encoded as follows:
//   - Record:    its own ToRecordBin output. A failure reports an error and returns ("", false).
//   - int:       packed with transform.PackLittleInt32.
//   - Primitive: PrimToString().
//   - bool:      a single 0x01 / 0x00 byte.
//   - string:    appended verbatim.
//
// Other types are skipped. If MemberTypeInfo.ToBin fails, ("", false) is returned.
// ref: https://learn.microsoft.com/en-us/openspecs/windows_protocols/ms-nrbf/847b0b6a-86af-4203-8ed0-f84345f845b9
func (classWithMembersAndTypesRecord ClassWithMembersAndTypesRecord) ToRecordBin() (string, bool) {
	var values strings.Builder
	for _, memberValue := range classWithMembersAndTypesRecord.MemberValues {
		switch v := memberValue.(type) {
		case Record:
			bin, ok := v.ToRecordBin()
			if !ok {
				output.PrintFrameworkError("Failed to convert member value into record")

				return "", false
			}
			values.WriteString(bin)
		case int:
			values.WriteString(transform.PackLittleInt32(v))
		case Primitive:
			values.WriteString(v.PrimToString())
		case bool:
			if v {
				values.WriteByte(0x01)
			} else {
				values.WriteByte(0x00)
			}
		case string:
			values.WriteString(v)
		}
	}

	recordType := string(byte(classWithMembersAndTypesRecord.GetRecordType()))
	libraryID := transform.PackLittleInt32(classWithMembersAndTypesRecord.LibraryID)
	typeInfo, ok := classWithMembersAndTypesRecord.MemberTypeInfo.ToBin()
	if !ok {
		return "", false
	}

	// layout: record type, class info (id, name, count, member names),
	// member type info (int8 types + additional info), library ID, member values
	return recordType + classWithMembersAndTypesRecord.ClassInfo.ToBin() + typeInfo + libraryID + values.String(), true
}
