import {isAfter, subMilliseconds} from 'date-fns';
import {AsyncSubject, Observable, of, Subscription} from 'rxjs';

import {AccessToken, TokenCreds} from '@azarus/api-contract';

import {API_GATEWAY_DEFAULT_ACCESS_TOKEN_REFRESH_SAFETY_FRAME_MS} from '../const';

/**
 * Decides whether an access token must be refreshed and coalesces concurrent
 * refreshes: while one refresh is pending, further subscribers share its result
 * instead of starting another request.
 */
export class ApiGatewayTokenRefreshHelper {
  // Shared result of the pending refresh; `null` when no refresh is in progress.
  private _tokenClaim: AsyncSubject<AccessToken | null> | null = null;
  // Number of live subscribers currently attached to `_tokenClaim`.
  private _claimersCount = 0;
  // Subscription that pipes the refresh source into `_tokenClaim`.
  private _refreshAccessTokenSubscription: Subscription | null = null;

  /**
   * Returns an observable yielding a usable access token.
   *
   * Note: the expiry checks run eagerly when this method is called, not when the
   * returned observable is subscribed.
   *
   * @param credentials - current token pair with their expiry timestamps.
   * @param onRefreshTokenExpiry - returned as-is when the refresh token has already expired.
   * @param refreshAccessToken - factory for the refresh request; invoked at most once per
   *   pending refresh, no matter how many subscribers share it.
   * @param safetyFrame - milliseconds subtracted from the access-token expiry so the token
   *   is refreshed slightly before it actually expires.
   * @returns `of(credentials.accessToken)` when the token is still fresh, the result of
   *   `onRefreshTokenExpiry()` when the refresh token expired, otherwise an observable that
   *   joins (or starts) the shared refresh.
   */
  public getFreshToken(
    credentials: TokenCreds,
    onRefreshTokenExpiry: () => Observable<AccessToken | null>,
    refreshAccessToken: () => Observable<AccessToken | null>,
    safetyFrame = API_GATEWAY_DEFAULT_ACCESS_TOKEN_REFRESH_SAFETY_FRAME_MS,
  ): Observable<AccessToken | null> {
    const dateNow = new Date();

    if (isAfter(dateNow, new Date(credentials.refreshTokenExpiresAt))) {
      return onRefreshTokenExpiry();
    }

    const adjustedAccessTokenExpiresAt = subMilliseconds(
      new Date(credentials.accessTokenExpiresAt),
      safetyFrame,
    );

    if (!isAfter(dateNow, adjustedAccessTokenExpiresAt)) {
      return of(credentials.accessToken);
    }

    return new Observable<AccessToken | null>((subscriber) => {
      // Only the first subscriber actually initiates the refresh; the rest
      // receive its result through the shared claim.
      const tokenClaim = this._claimToken(refreshAccessToken);

      this._claimersCount++;

      const subscription = tokenClaim.subscribe(subscriber);

      return () => {
        subscription.unsubscribe();
        // The last claimer to leave (by completion, error, or unsubscribe)
        // drops the claim and cancels the refresh if it is still running.
        if (--this._claimersCount === 0) {
          this._releaseClaim();
        }
      };
    });
  }

  /** Returns the pending claim, starting a new refresh when none exists. */
  private _claimToken(
    refreshAccessToken: () => Observable<AccessToken | null>,
  ): AsyncSubject<AccessToken | null> {
    if (this._tokenClaim !== null) {
      return this._tokenClaim;
    }

    const tokenClaim = new AsyncSubject<AccessToken | null>();
    this._tokenClaim = tokenClaim;

    try {
      this._refreshAccessTokenSubscription =
        refreshAccessToken().subscribe(tokenClaim);
    } catch (error: unknown) {
      // A synchronous throw means nothing will ever settle this claim; keeping
      // it would make every later refresh subscriber wait forever.
      this._tokenClaim = null;
      this._refreshAccessTokenSubscription = null;
      throw error;
    }

    return tokenClaim;
  }

  /** Clears the shared claim and cancels the underlying refresh subscription. */
  private _releaseClaim(): void {
    const refreshSubscription = this._refreshAccessTokenSubscription;

    // Reset state before unsubscribing so any re-entrant call sees a clean slate.
    this._tokenClaim = null;
    this._refreshAccessTokenSubscription = null;

    refreshSubscription?.unsubscribe();
  }
}
