package main

import (
	"crypto/aes"
	"crypto/cipher"
	"crypto/rand"
)

// encrypt encrypts plaintext with AES in CBC mode under key.
//
// The plaintext is padded by pad (PKCS#7-style) to a multiple of
// aes.BlockSize, and a fresh random IV is generated for every call. The
// returned slice is the IV followed by the ciphertext, which is the layout
// decrypt expects. key must be a valid AES key length (16, 24 or 32 bytes);
// otherwise the error from aes.NewCipher is returned.
//
// NOTE: CBC provides confidentiality only. The output is not authenticated
// and can be modified undetected; prefer cipher.NewGCM for new code.
func encrypt(plaintext, key []byte) ([]byte, error) {
	block, err := aes.NewCipher(key)
	if err != nil {
		return nil, err
	}

	padded := pad(plaintext)

	// Allocate the result once: the IV occupies the first block and the
	// ciphertext is written directly after it.
	out := make([]byte, aes.BlockSize+len(padded))
	iv := out[:aes.BlockSize]

	// CBC requires an unpredictable IV. If randomness cannot be read, abort
	// rather than silently encrypting under an all-zero IV.
	if _, err = rand.Read(iv); err != nil {
		return nil, err
	}

	cipher.NewCBCEncrypter(block, iv).CryptBlocks(out[aes.BlockSize:], padded)
	return out, nil
}

// decryptError is the error type decrypt returns for malformed input. It is a
// string type so that the specific errors below can be declared as constants.
type decryptError string

func (e decryptError) Error() string { return string(e) }

const (
	errCiphertextTooShort   = decryptError("decrypt: ciphertext shorter than IV plus one block")
	errCiphertextNotAligned = decryptError("decrypt: ciphertext length is not a multiple of the block size")
	errInvalidPadding       = decryptError("decrypt: invalid padding")
)

// decrypt reverses encrypt. It takes the IV followed by AES-CBC ciphertext,
// decrypts it under key, and strips the padding with unpad.
//
// It returns an error if key is not a valid AES key, if the input is shorter
// than the IV plus one block, if the ciphertext is not block-aligned, or if
// the decrypted padding is malformed. It never panics on malformed input.
//
// NOTE(review): this ciphertext is unauthenticated. If callers expose whether
// decryption failed on padding, that forms a padding oracle. Consider
// authenticated encryption (cipher.NewGCM) or encrypt-then-MAC.
func decrypt(encryptedIVandData, key []byte) ([]byte, error) {
	block, err := aes.NewCipher(key)
	if err != nil {
		return nil, err
	}

	// encrypt always emits the IV plus at least one padded block. Shorter
	// input cannot be valid, and would otherwise panic when sliced or yield
	// a silent (nil, nil) result.
	if len(encryptedIVandData) < 2*aes.BlockSize {
		return nil, errCiphertextTooShort
	}
	iv := encryptedIVandData[:aes.BlockSize]
	encrypted := encryptedIVandData[aes.BlockSize:]

	// CryptBlocks panics on partial blocks, so reject them up front.
	if len(encrypted)%aes.BlockSize != 0 {
		return nil, errCiphertextNotAligned
	}

	plaintextPadded := make([]byte, len(encrypted))
	cipher.NewCBCDecrypter(block, iv).CryptBlocks(plaintextPadded, encrypted)

	// unpad signals failure with nil. A valid empty plaintext comes back as a
	// non-nil empty slice, so the nil check is the failure test.
	plaintext := unpad(plaintextPadded)
	if plaintext == nil {
		return nil, errInvalidPadding
	}
	return plaintext, nil
}

// unpad removes PKCS#7-style padding from in and returns the remaining prefix.
//
// The last byte n of in gives the padding length. It must satisfy
// 1 <= n <= aes.BlockSize and n <= len(in), and the final n bytes must all
// equal n. If in is empty or any check fails, unpad returns nil. On success
// the result is a subslice of in and shares its backing array. When the whole
// input is padding, the result is a non-nil empty slice.
func unpad(in []byte) []byte {
	size := len(in)
	if size == 0 {
		return nil
	}

	n := int(in[size-1])
	if n == 0 || n > aes.BlockSize || n > size {
		return nil
	}

	cut := size - n
	for _, b := range in[cut:] {
		if int(b) != n {
			return nil
		}
	}
	return in[:cut]
}

// pad applies PKCS#7-style padding to in and returns the padded result.
//
// Between 1 and aes.BlockSize bytes are added, each equal to the number of
// bytes added. The result length is therefore the next multiple of
// aes.BlockSize strictly greater than len(in). Block-aligned input gains a
// full block of padding, which keeps the padding non-zero so unpad can always
// remove it.
//
// The result is always a newly allocated slice. The caller's slice and its
// backing array are never written to, even when in has spare capacity (the
// earlier append-based version could overwrite memory beyond len(in)).
func pad(in []byte) []byte {
	n := aes.BlockSize - len(in)%aes.BlockSize
	out := make([]byte, len(in)+n)
	copy(out, in)
	for i := len(in); i < len(out); i++ {
		out[i] = byte(n)
	}
	return out
}