// Copyright IBM Corp. 2015, 2025
// SPDX-License-Identifier: BUSL-1.1

package taskrunner

import (
	"context"
	"fmt"
	"os"
	"path/filepath"
	"slices"
	"strings"

	log "github.com/hashicorp/go-hclog"
	"github.com/hashicorp/go-multierror"
	"github.com/hashicorp/nomad/client/allocrunner/interfaces"
	cstructs "github.com/hashicorp/nomad/client/structs"
	"github.com/hashicorp/nomad/nomad/structs"
)

const (
	// consulTokenFilename is the name of the file, inside the task's secrets
	// directory, that the Consul token's secret ID is written to.
	consulTokenFilename = "consul_token"

	// consulTokenFilePerms is the permission mode passed when writing the
	// token file: read/write for the owner, read-only for the group, and no
	// access for others.
	consulTokenFilePerms = 0640
)

// consulHook is a task runner hook that hands a task its Consul token.
//
// task is the task the hook runs for; its name and identities decide which
// tokens apply. tokenDir is the directory the token file is written to (the
// task's secrets directory). hookResources is the source the Consul tokens
// are read from. logger is the hook's logger, named after the hook.
type consulHook struct {
	task          *structs.Task
	tokenDir      string
	hookResources *cstructs.AllocHookResources

	logger log.Logger
}

// newConsulHook builds the Consul hook for the given task runner. The hook
// takes its task, secrets directory and alloc hook resources from tr, and
// logs through a sub-logger of logger named after the hook.
func newConsulHook(logger log.Logger, tr *TaskRunner) *consulHook {
	hook := new(consulHook)
	hook.task = tr.Task()
	hook.tokenDir = tr.taskDir.SecretsDir
	hook.hookResources = tr.allocHookResources
	hook.logger = logger.Named(hook.Name())
	return hook
}

// Name returns the identifier of this hook. newConsulHook also uses it to
// name the hook's logger.
func (*consulHook) Name() string {
	const hookName = "consul_task"
	return hookName
}

// Prestart writes the Consul token that belongs to this task into the task's
// secrets directory and exposes it through the task environment.
//
// Token names obtained from the alloc hook resources are split at the first
// "/" into "<identity>/<task>"; names without a "/" are skipped. A token is
// used only when its task part equals this task's name and its identity part
// names one of the task's identities (see hasIdentity).
//
// For each matching token the secret ID is written to consulTokenFilename in
// the secrets directory, and resp.Env is replaced with CONSUL_TOKEN and
// CONSUL_HTTP_TOKEN set to that secret ID. Write failures are collected and
// returned together; the environment is still set when a write fails.
//
// ctx and req are not used.
func (h *consulHook) Prestart(ctx context.Context, req *interfaces.TaskPrestartRequest, resp *interfaces.TaskPrestartResponse) error {
	mErr := multierror.Error{}

	// Every matching token is written to the same file, so the path does not
	// depend on the loop variables.
	tokenPath := filepath.Join(h.tokenDir, consulTokenFilename)

	for _, tokens := range h.hookResources.GetConsulTokens() {
		for tokenName, token := range tokens {
			identity, taskName, ok := strings.Cut(tokenName, "/")
			if !ok {
				continue
			}

			// Do not write tokens that belong to another task or to an
			// identity this task does not have.
			if taskName != h.task.Name || !h.hasIdentity(identity) {
				continue
			}

			if err := os.WriteFile(tokenPath, []byte(token.SecretID), consulTokenFilePerms); err != nil {
				mErr.Errors = append(mErr.Errors, fmt.Errorf("failed to write Consul SI token: %w", err))
			}

			resp.Env = map[string]string{
				"CONSUL_TOKEN":      token.SecretID,
				"CONSUL_HTTP_TOKEN": token.SecretID,
			}
		}
	}

	return mErr.ErrorOrNil()
}

// hasIdentity reports whether name is the name of one of the task's
// identities: either an entry in task.Identities or the task's default
// identity. A nil default identity or nil entry never matches.
func (h *consulHook) hasIdentity(name string) bool {
	matches := func(id *structs.WorkloadIdentity) bool {
		return id != nil && id.Name == name
	}
	return slices.ContainsFunc(h.task.Identities, matches) || matches(h.task.Identity)
}
