/*
Copyright 2020 The Knative Authors

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package kncloudevents

import (
	"context"
	"net"
	nethttp "net/http"
	"net/http/httptest"
	"sync/atomic"
	"testing"
	"time"

	"github.com/cloudevents/sdk-go/v2/binding/buffering"
	bindingtest "github.com/cloudevents/sdk-go/v2/binding/test"
	cehttp "github.com/cloudevents/sdk-go/v2/protocol/http"
	cetest "github.com/cloudevents/sdk-go/v2/test"
	"github.com/stretchr/testify/assert"
	"k8s.io/utils/pointer"

	duckv1 "knative.dev/eventing/pkg/apis/duck/v1"
	eventingduck "knative.dev/eventing/pkg/apis/duck/v1"
)

// TestRetryConfigFromDeliverySpec verifies that RetryConfigFromDeliverySpec copies
// DeliverySpec.Retry into RetryConfig.RetryMax, builds a Backoff function whose
// value matches the configured policy and delay for every attempt, and returns an
// error when BackoffDelay cannot be parsed.
func TestRetryConfigFromDeliverySpec(t *testing.T) {

	// TestCase describes one DeliverySpec and the RetryConfig expected from it.
	// expectedBackoffDurations[i] is the expected Backoff for attempt i+1; every
	// successful case must list exactly `retry` entries.
	type TestCase struct {
		name                     string
		retry                    int32
		backoffPolicy            duckv1.BackoffPolicyType
		backoffDelay             string
		expectedBackoffDurations []time.Duration
		wantErr                  bool
	}

	testcases := []TestCase{
		{
			name:          "Successful Linear Backoff 2500ms",
			retry:         int32(5),
			backoffPolicy: duckv1.BackoffPolicyLinear,
			backoffDelay:  "PT2.5S",
			expectedBackoffDurations: []time.Duration{
				2500 * time.Millisecond,
				2500 * time.Millisecond,
				2500 * time.Millisecond,
				2500 * time.Millisecond,
				2500 * time.Millisecond,
			},
			wantErr: false,
		},
		{
			name:          "Successful Exponential Backoff 1500ms",
			retry:         int32(5),
			backoffPolicy: duckv1.BackoffPolicyExponential,
			backoffDelay:  "PT1.5S",
			expectedBackoffDurations: []time.Duration{
				3 * time.Second,
				6 * time.Second,
				12 * time.Second,
				24 * time.Second,
				48 * time.Second,
			},
			wantErr: false,
		},
		{
			name:          "Successful Exponential Backoff 500ms",
			retry:         int32(5),
			backoffPolicy: duckv1.BackoffPolicyExponential,
			backoffDelay:  "PT0.5S",
			expectedBackoffDurations: []time.Duration{
				1 * time.Second,
				2 * time.Second,
				4 * time.Second,
				8 * time.Second,
				16 * time.Second,
			},
			wantErr: false,
		},
		{
			name:          "Invalid Backoff Delay",
			retry:         int32(5),
			backoffPolicy: duckv1.BackoffPolicyLinear,
			backoffDelay:  "FOO",
			wantErr:       true,
		},
	}

	for _, testcase := range testcases {
		// Shadow the loop variable: the DeliverySpec below takes pointers into it.
		testcase := testcase

		t.Run(testcase.name, func(t *testing.T) {

			deliverySpec := duckv1.DeliverySpec{
				DeadLetterSink: nil,
				Retry:          &testcase.retry,
				BackoffPolicy:  &testcase.backoffPolicy,
				BackoffDelay:   &testcase.backoffDelay,
			}

			retryConfig, err := RetryConfigFromDeliverySpec(deliverySpec)
			assert.Equal(t, testcase.wantErr, err != nil)
			if err != nil {
				return
			}

			// Validate RetryMax, then the Backoff of every attempt, including the last.
			assert.Equal(t, int(testcase.retry), retryConfig.RetryMax)
			assert.Len(t, testcase.expectedBackoffDurations, int(testcase.retry))
			for i, expectedBackoffDuration := range testcase.expectedBackoffDurations {
				attempt := i + 1 // this test numbers attempts starting at 1
				actualBackoffDuration := retryConfig.Backoff(attempt, nil)
				assert.Equal(t, expectedBackoffDuration, actualBackoffDuration, "backoff for attempt %d", attempt)
			}
		})
	}
}

// TestHttpMessageSenderSendWithRetries checks how many times SendWithRetries
// dispatches a request to a server that always answers with wantStatus. When
// CheckRetry always requests a retry, it expects 1+RetryMax dispatches. With no
// RetryConfig it expects exactly one dispatch.
func TestHttpMessageSenderSendWithRetries(t *testing.T) {
	t.Parallel()

	tests := []struct {
		name         string
		config       *RetryConfig
		wantStatus   int
		wantDispatch int
		wantErr      bool
	}{
		{
			name: "5 max retry",
			config: &RetryConfig{
				RetryMax: 5,
				CheckRetry: func(ctx context.Context, resp *nethttp.Response, err error) (bool, error) {
					return true, nil
				},
				Backoff: func(attemptNum int, resp *nethttp.Response) time.Duration {
					return time.Millisecond
				},
			},
			wantStatus:   nethttp.StatusServiceUnavailable,
			wantDispatch: 6,
			wantErr:      false,
		},
		{
			name: "1 max retry",
			config: &RetryConfig{
				RetryMax: 1,
				CheckRetry: func(ctx context.Context, resp *nethttp.Response, err error) (bool, error) {
					return true, nil
				},
				Backoff: func(attemptNum int, resp *nethttp.Response) time.Duration {
					return time.Millisecond
				},
			},
			wantStatus:   nethttp.StatusServiceUnavailable,
			wantDispatch: 2,
			wantErr:      false,
		},
		{
			name:         "with no retryConfig",
			wantStatus:   nethttp.StatusServiceUnavailable,
			wantDispatch: 1,
			wantErr:      false,
		},
	}
	for _, tt := range tests {
		tt := tt // the handler closure below captures tt
		t.Run(tt.name, func(t *testing.T) {
			// n counts requests that reach the handler; it is written by server goroutines.
			var n int32
			server := httptest.NewServer(nethttp.HandlerFunc(func(writer nethttp.ResponseWriter, request *nethttp.Request) {
				atomic.AddInt32(&n, 1)
				writer.WriteHeader(tt.wantStatus)
			}))
			defer server.Close()

			sender := &HttpMessageSender{
				Client: nethttp.DefaultClient,
			}

			request, err := nethttp.NewRequest("POST", server.URL, nil)
			assert.Nil(t, err)
			got, err := sender.SendWithRetries(request, tt.config)
			if (err != nil) != tt.wantErr || got == nil {
				t.Errorf("SendWithRetries() error = %v, wantErr %v or got nil", err, tt.wantErr)
				return
			}
			defer got.Body.Close()

			if got.StatusCode != tt.wantStatus {
				t.Errorf("SendWithRetries() got = %v, want %v", got.StatusCode, tt.wantStatus)
				return
			}
			// Report wantDispatch rather than tt.config.RetryMax: config is nil
			// in the "with no retryConfig" case.
			if count := int(atomic.LoadInt32(&n)); count != tt.wantDispatch {
				t.Fatalf("expected %d dispatches got %d", tt.wantDispatch, count)
			}
		})
	}
}

// TestRetriesOnNetworkErrors checks that SendWithRetries keeps retrying both
// while the target refuses connections and, once a server is brought up, while
// that server answers 503. It stops after the server finally answers 200.
//
// A wrapped CheckRetry does a handshake with a helper goroutine on every call.
// This lets the goroutine count calls and start the server at a chosen point.
func TestRetriesOnNetworkErrors(t *testing.T) {

	n := int32(10)
	linear := duckv1.BackoffPolicyLinear
	// Fixed loopback address that nothing listens on at first, so early attempts
	// fail with a connection error.
	// NOTE(review): a hard-coded port can collide with another process on the
	// host and make this test flaky.
	target := "127.0.0.1:63468"

	// calls/cont form a handshake. The wrapped CheckRetry sends on calls and then
	// blocks on cont until the goroutine below has updated nCalls (and possibly
	// started the server). Deferred closes run in reverse order: cont first,
	// then calls, which ends the goroutine's range loop.
	calls := make(chan struct{})
	defer close(calls)

	// Number of CheckRetry invocations. Only the goroutine below increments it,
	// and each increment happens before that CheckRetry call is released via cont.
	nCalls := int32(0)

	cont := make(chan struct{})
	defer close(cont)

	go func() {
		for range calls {

			nCalls++
			// Simulate that the target service comes back up.
			//
			// Until the n/2-th CheckRetry call nothing listens on target, so those
			// attempts are refused. At that point we start a server that answers 503.
			// The client is then expected to keep retrying, now for a different reason.
			//
			// Once nCalls reaches n-1 the handler writes no header, so the default
			// 200 is sent and no further retry is expected.
			if n/2 == nCalls {

				l, err := net.Listen("tcp", target)
				assert.Nil(t, err)

				s := httptest.NewUnstartedServer(nethttp.HandlerFunc(func(writer nethttp.ResponseWriter, request *nethttp.Request) {
					if n-1 != nCalls {
						writer.WriteHeader(nethttp.StatusServiceUnavailable)
						return
					}
				}))
				defer s.Close() //nolint // defers in this range loop won't run unless the channel gets closed

				// Swap the listener httptest picked for one bound to target, so the
				// request's fixed URL reaches this server.
				assert.Nil(t, s.Listener.Close())

				s.Listener = l

				s.Start()
			}
			// Release the CheckRetry call that is waiting in the handshake.
			cont <- struct{}{}
		}
	}()

	r, err := RetryConfigFromDeliverySpec(duckv1.DeliverySpec{
		Retry:         pointer.Int32Ptr(n),
		BackoffPolicy: &linear,
		BackoffDelay:  pointer.StringPtr("PT0.1S"),
	})
	assert.Nil(t, err)

	checkRetry := r.CheckRetry

	// Route every retry decision through the handshake, then delegate to the
	// CheckRetry that RetryConfigFromDeliverySpec produced.
	r.CheckRetry = func(ctx context.Context, resp *nethttp.Response, err error) (bool, error) {
		calls <- struct{}{}
		<-cont

		return checkRetry(ctx, resp, err)
	}

	req, err := nethttp.NewRequest("POST", "http://"+target, nil)
	assert.Nil(t, err)

	sender, err := NewHttpMessageSender(nil, "")
	assert.Nil(t, err)

	_, err = sender.SendWithRetries(req, &r)
	assert.Nil(t, err)

	// nCalls is the number of CheckRetry invocations. The handler answers 200
	// once nCalls == n-1, so the n-th attempt succeeds. Expecting exactly n
	// calls therefore assumes CheckRetry also runs for that successful response.
	// NOTE(review): this relies on SendWithRetries calling CheckRetry once per
	// attempt, including the last one; confirm against its implementation.
	assert.Equal(t, n, nCalls, "expected %d got %d", n, nCalls)
}

// TestHTTPMessageSenderSendWithRetriesWithBufferedMessage sends a request whose
// body is written from a buffered binary-mode CloudEvents message. The server
// answers 503 to the first wantToSkip requests and 202 afterwards. The test
// expects the final response to be 202, after exactly wantToSkip+1 requests.
func TestHTTPMessageSenderSendWithRetriesWithBufferedMessage(t *testing.T) {
	t.Parallel()

	const wantToSkip = 9
	config := &RetryConfig{
		RetryMax: wantToSkip,
		// Always ask for a retry; RetryMax is what bounds the number of attempts.
		CheckRetry: func(ctx context.Context, resp *nethttp.Response, err error) (bool, error) {
			return true, nil
		},
		Backoff: func(attemptNum int, resp *nethttp.Response) time.Duration {
			return time.Millisecond * 50 * time.Duration(attemptNum)
		},
	}

	// n counts requests that reach the handler; it is written by server goroutines.
	var n uint32
	server := httptest.NewServer(nethttp.HandlerFunc(func(writer nethttp.ResponseWriter, request *nethttp.Request) {
		thisReqN := atomic.AddUint32(&n, 1)
		if thisReqN <= wantToSkip {
			writer.WriteHeader(nethttp.StatusServiceUnavailable)
		} else {
			writer.WriteHeader(nethttp.StatusAccepted)
		}
	}))
	defer server.Close()

	sender := &HttpMessageSender{
		Client: nethttp.DefaultClient,
	}

	request, err := nethttp.NewRequest("POST", server.URL, nil)
	assert.Nil(t, err)

	// Create a message similar to the one we send with channels
	mockMessage := bindingtest.MustCreateMockBinaryMessage(cetest.FullEvent())
	bufferedMessage, err := buffering.BufferMessage(context.TODO(), mockMessage)
	assert.Nil(t, err)

	err = cehttp.WriteRequest(context.TODO(), bufferedMessage, request)
	assert.Nil(t, err)

	got, err := sender.SendWithRetries(request, config)
	if err != nil {
		t.Fatalf("SendWithRetries() error = %v, wantErr nil", err)
	}
	defer got.Body.Close()

	if got.StatusCode != nethttp.StatusAccepted {
		t.Fatalf("SendWithRetries() got = %v, want %v", got.StatusCode, nethttp.StatusAccepted)
	}
	if count := atomic.LoadUint32(&n); count != wantToSkip+1 {
		t.Fatalf("expected %d count got %d", wantToSkip+1, count)
	}
}

// TestRetryConfigFromDeliverySpecCheckRetry verifies three things about
// RetryConfigFromDeliverySpec:
//   - it populates CheckRetry and Backoff, with or without a BackoffDelay;
//   - it copies DeliverySpec.Retry into RetryMax;
//   - it rejects a BackoffDelay that is not an ISO 8601 duration.
func TestRetryConfigFromDeliverySpecCheckRetry(t *testing.T) {
	// retryMax is used both as the DeliverySpec input and as the expected RetryMax.
	const retryMax = 10
	linear := eventingduck.BackoffPolicyLinear
	tests := []struct {
		name     string
		spec     eventingduck.DeliverySpec
		retryMax int
		wantErr  bool
	}{
		{
			name: "full delivery",
			spec: eventingduck.DeliverySpec{
				Retry:         pointer.Int32Ptr(retryMax),
				BackoffPolicy: &linear,
				BackoffDelay:  pointer.StringPtr("PT1S"),
			},
			retryMax: retryMax,
			wantErr:  false,
		},
		{
			name: "only retry",
			spec: eventingduck.DeliverySpec{
				Retry:         pointer.Int32Ptr(retryMax),
				BackoffPolicy: &linear,
			},
			retryMax: retryMax,
			wantErr:  false,
		},
		{
			name: "not ISO8601",
			spec: eventingduck.DeliverySpec{
				Retry:         pointer.Int32Ptr(retryMax),
				BackoffDelay:  pointer.StringPtr("PP1"),
				BackoffPolicy: &linear,
			},
			retryMax: retryMax,
			wantErr:  true,
		},
	}

	for _, tt := range tests {
		tt := tt
		t.Run(tt.name, func(t *testing.T) {
			got, err := RetryConfigFromDeliverySpec(tt.spec)
			if (err != nil) != tt.wantErr {
				t.Errorf("RetryConfigFromDeliverySpec() error = %v, wantErr %v", err, tt.wantErr)
				return
			}
			if err != nil {
				// Expected failure: the returned config is not inspected.
				return
			}

			if got.CheckRetry == nil {
				t.Errorf("CheckRetry must not be nil")
				return
			}
			if got.Backoff == nil {
				t.Errorf("Backoff must not be nil")
			}
			if got.RetryMax != tt.retryMax {
				t.Errorf("retryMax want %d got %d", tt.retryMax, got.RetryMax)
			}
		})
	}
}
