diff --git a/web/hooks/use-save-referral.ts b/web/hooks/use-save-referral.ts index 788268b0..7772f9d2 100644 --- a/web/hooks/use-save-referral.ts +++ b/web/hooks/use-save-referral.ts @@ -18,10 +18,14 @@ export const useSaveReferral = ( referrer?: string } - const actualReferrer = referrer || options?.defaultReferrer + const referrerOrDefault = referrer || options?.defaultReferrer - if (!user && router.isReady && actualReferrer) { - writeReferralInfo(actualReferrer, options?.contractId, options?.groupId) + if (!user && router.isReady && referrerOrDefault) { + writeReferralInfo(referrerOrDefault, { + contractId: options?.contractId, + overwriteReferralUsername: referrer, + groupId: options?.groupId, + }) } }, [user, router, options]) } diff --git a/web/lib/firebase/users.ts b/web/lib/firebase/users.ts index 4f618586..5e00affe 100644 --- a/web/lib/firebase/users.ts +++ b/web/lib/firebase/users.ts @@ -96,22 +96,25 @@ const CACHED_REFERRAL_GROUP_ID_KEY = 'CACHED_REFERRAL_GROUP_KEY' export function writeReferralInfo( defaultReferrerUsername: string, - contractId?: string, - referralUsername?: string, - groupId?: string + otherOptions?: { + contractId?: string + overwriteReferralUsername?: string + groupId?: string + } ) { const local = safeLocalStorage() const cachedReferralUser = local?.getItem(CACHED_REFERRAL_USERNAME_KEY) + const { contractId, overwriteReferralUsername, groupId } = otherOptions || {} // Write the first referral username we see. if (!cachedReferralUser) local?.setItem( CACHED_REFERRAL_USERNAME_KEY, - referralUsername || defaultReferrerUsername + overwriteReferralUsername || defaultReferrerUsername ) // If an explicit referral query is passed, overwrite the cached referral username. - if (referralUsername) - local?.setItem(CACHED_REFERRAL_USERNAME_KEY, referralUsername) + if (overwriteReferralUsername) + local?.setItem(CACHED_REFERRAL_USERNAME_KEY, overwriteReferralUsername) // Always write the most recent explicit group invite query value if (groupId) local?.setItem(CACHED_REFERRAL_GROUP_ID_KEY, groupId)