Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Sync default payment method to the backend #10172

Merged
merged 5 commits into from
Feb 18, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -311,6 +311,14 @@ abstract class AbsFakeStripeRepository : StripeRepository {
TODO("Not yet implemented")
}

override suspend fun setDefaultPaymentMethod(
customerId: String,
paymentMethodId: String?,
options: ApiRequest.Options
): Result<Customer> {
TODO("Not yet implemented")
}

override suspend fun logOut(
consumerSessionClientSecret: String,
consumerAccountPublishableKey: String?,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -605,6 +605,21 @@ class StripeApiRepository @JvmOverloads internal constructor(
}
}

override suspend fun setDefaultPaymentMethod(
customerId: String,
paymentMethodId: String?,
options: ApiRequest.Options
): Result<Customer> {
return fetchStripeModelResult(
apiRequest = apiRequestFactory.createPost(
url = getSetDefaultPaymentMethodUrl(customerId = customerId),
options = options,
params = mapOf("payment_method" to (paymentMethodId ?: ""))
),
jsonParser = CustomerJsonParser()
)
}

/**
* Create a [Token] using the input token parameters.
*
Expand Down Expand Up @@ -2139,6 +2154,16 @@ class StripeApiRepository @JvmOverloads internal constructor(
return getApiUrl("payment_methods/$paymentMethodId")
}

/**
* @return `https://api.stripe.com/v1/elements/customers/:customerId/set_default_payment_method`
*/
@VisibleForTesting
internal fun getSetDefaultPaymentMethodUrl(
customerId: String,
): String {
return getApiUrl("elements/customers/$customerId/set_default_payment_method")
}

private fun getApiUrl(path: String, vararg args: Any): String {
return getApiUrl(String.format(Locale.ENGLISH, path, *args))
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,20 @@ interface StripeRepository {
options: ApiRequest.Options
): Result<PaymentMethod>

/**
* Set the customer's default payment method.
*
* @param customerId Id of the customer to update
* @param paymentMethodId Id of the payment method to set as the default. If null, the user's existing default
* payment method will be unset.
* */
@RestrictTo(RestrictTo.Scope.LIBRARY_GROUP)
suspend fun setDefaultPaymentMethod(
customerId: String,
paymentMethodId: String?,
options: ApiRequest.Options,
): Result<Customer>

@RestrictTo(RestrictTo.Scope.LIBRARY_GROUP)
suspend fun createToken(
tokenParams: TokenParams,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,17 @@ internal class StripeApiRepositoryTest {
assertThat(attachUrl).isEqualTo(expectedUrl)
}

@Test
fun testSetDefaultPaymentMethodUrl() {
val customerId = "cus_123"
val setDefaultPaymentMethodUrl = StripeApiRepository.getSetDefaultPaymentMethodUrl(
customerId
)
assertThat(setDefaultPaymentMethodUrl).isEqualTo(
"https://api.stripe.com/v1/elements/customers/$customerId/set_default_payment_method"
)
}

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we add a unit test for the parameters as well? The parameters tests are found found below the URL tests.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes! Just added a couple

@Test
fun testGetDetachPaymentMethodUrl() {
val paymentMethodId = "pm_1ETDEa2eZvKYlo2CN5828c52"
Expand Down Expand Up @@ -1491,6 +1502,49 @@ internal class StripeApiRepositoryTest {
)
}

@Test
fun setDefaultPaymentMethod_sendPaymentMethodParameter() = runTest {
val stripeResponse = StripeResponse(
code = 200,
body = "",
headers = emptyMap()
)
whenever(stripeNetworkClient.executeRequest(any<ApiRequest>()))
.thenReturn(stripeResponse)

val expectedPaymentMethodId = "pm_123"
create().setDefaultPaymentMethod(
customerId = "cus_123",
paymentMethodId = expectedPaymentMethodId,
DEFAULT_OPTIONS,
)

verify(stripeNetworkClient).executeRequest(apiRequestArgumentCaptor.capture())
val apiRequest = apiRequestArgumentCaptor.firstValue
assertThat(apiRequest.params?.get("payment_method")).isEqualTo(expectedPaymentMethodId)
}

@Test
fun setDefaultPaymentMethod_sendsNullPaymentMethodAsEmptyString() = runTest {
val stripeResponse = StripeResponse(
code = 200,
body = "",
headers = emptyMap()
)
whenever(stripeNetworkClient.executeRequest(any<ApiRequest>()))
.thenReturn(stripeResponse)

create().setDefaultPaymentMethod(
customerId = "cus_123",
paymentMethodId = null,
DEFAULT_OPTIONS,
)

verify(stripeNetworkClient).executeRequest(apiRequestArgumentCaptor.capture())
val apiRequest = apiRequestArgumentCaptor.firstValue
assertThat(apiRequest.params?.get("payment_method")).isEqualTo("")
}

@Test
fun createCardPaymentMethod_setsCorrectPaymentUserAgent() =
runTest {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -155,9 +155,10 @@ internal class CustomerSessionCustomerSheetTest {
enqueueSetupIntentRetrieval()
enqueueSetupIntentConfirmation()

val paymentMethodId = "pm_12345"
enqueueElementsSession(
cards = listOf(
PaymentMethodFactory.card(id = "pm_12345").update(
PaymentMethodFactory.card(id = paymentMethodId).update(
last4 = "4242",
addCbcNetworks = false,
brand = CardBrand.Visa,
Expand All @@ -170,7 +171,7 @@ internal class CustomerSessionCustomerSheetTest {
page.clickSaveButton()
assertOnlySavedCardIsDisplayed()

page.clickConfirmButton()
context.markTestSucceeded()
}

private fun assertOnlySavedCardIsDisplayed() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1085,7 +1085,9 @@ internal class CustomerSheetViewModel(
private fun selectSavedPaymentMethod(savedPaymentSelection: PaymentSelection.Saved?) {
viewModelScope.launch(workContext) {
awaitSavedSelectionDataSource().setSavedSelection(
savedPaymentSelection?.toSavedSelection()
savedPaymentSelection?.toSavedSelection(),
shouldSyncDefault =
customerState.value.metadata?.customerMetadata?.isPaymentMethodSetAsDefaultEnabled == true,
).onSuccess {
confirmPaymentSelection(
paymentSelection = savedPaymentSelection,
Expand All @@ -1104,7 +1106,7 @@ internal class CustomerSheetViewModel(

private fun selectGooglePay() {
viewModelScope.launch(workContext) {
awaitSavedSelectionDataSource().setSavedSelection(SavedSelection.GooglePay)
awaitSavedSelectionDataSource().setSavedSelection(SavedSelection.GooglePay, shouldSyncDefault = false)
.onSuccess {
confirmPaymentSelection(
paymentSelection = PaymentSelection.GooglePay,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,10 @@ internal class CustomerAdapterDataSource @Inject constructor(
}
}

override suspend fun setSavedSelection(selection: SavedSelection?) = runCatchingAdapterTask {
override suspend fun setSavedSelection(
selection: SavedSelection?,
shouldSyncDefault: Boolean
) = runCatchingAdapterTask {
customerAdapter.setSelectedPaymentOption(selection?.toPaymentOption())
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,15 @@ import com.stripe.android.model.ElementsSession
import com.stripe.android.paymentsheet.PrefsRepository
import com.stripe.android.paymentsheet.model.SavedSelection
import com.stripe.android.paymentsheet.model.toSavedSelection
import com.stripe.android.paymentsheet.repositories.CustomerRepository
import kotlinx.coroutines.withContext
import java.io.IOException
import javax.inject.Inject
import kotlin.coroutines.CoroutineContext

internal class CustomerSessionSavedSelectionDataSource @Inject constructor(
private val elementsSessionManager: CustomerSessionElementsSessionManager,
private val customerRepository: CustomerRepository,
private val prefsRepositoryFactory: @JvmSuppressWildcards (String) -> PrefsRepository,
@IOContext private val workContext: CoroutineContext,
) : CustomerSheetSavedSelectionDataSource {
Expand Down Expand Up @@ -63,18 +65,48 @@ internal class CustomerSessionSavedSelectionDataSource @Inject constructor(
}
}

override suspend fun setSavedSelection(selection: SavedSelection?): CustomerSheetDataResult<Unit> {
override suspend fun setSavedSelection(
selection: SavedSelection?,
shouldSyncDefault: Boolean,
): CustomerSheetDataResult<Unit> {
return withContext(workContext) {
createPrefsRepository().mapCatching { prefsRepository ->
val result = prefsRepository.setSavedSelection(selection)

if (!result) {
throw IOException("Unable to persist payment option $selection")
elementsSessionManager.fetchCustomerSessionEphemeralKey().mapCatching { ephemeralKey ->
if (shouldSyncDefault) {
saveSelectionToBackend(ephemeralKey, selection)
} else {
saveSelectionToPrefs(selection)
}
}.toCustomerSheetDataResult()
}
}

private suspend fun saveSelectionToPrefs(
selection: SavedSelection?
) {
createPrefsRepository().mapCatching { prefsRepository ->
val result = prefsRepository.setSavedSelection(selection)

if (!result) {
throw IOException("Unable to persist payment option $selection")
}
}
}

private suspend fun saveSelectionToBackend(
ephemeralKey: CachedCustomerEphemeralKey,
selection: SavedSelection?
) {
val paymentMethodId = (selection as? SavedSelection.PaymentMethod)?.id
customerRepository.setDefaultPaymentMethod(
paymentMethodId = paymentMethodId,
customerInfo = CustomerRepository.CustomerInfo(
id = ephemeralKey.customerId,
ephemeralKeySecret = ephemeralKey.ephemeralKey,
customerSessionClientSecret = ephemeralKey.customerSessionClientSecret,
)
).getOrThrow()
}

private suspend fun createPrefsRepository(): CustomerSheetDataResult<PrefsRepository> {
return elementsSessionManager.fetchCustomerSessionEphemeralKey().mapCatching { ephemeralKey ->
prefsRepositoryFactory(ephemeralKey.customerId)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,5 +16,8 @@ internal interface CustomerSheetSavedSelectionDataSource {
customerSessionElementsSession: CustomerSessionElementsSession?
): CustomerSheetDataResult<SavedSelection?>

suspend fun setSavedSelection(selection: SavedSelection?): CustomerSheetDataResult<Unit>
suspend fun setSavedSelection(
selection: SavedSelection?,
shouldSyncDefault: Boolean,
): CustomerSheetDataResult<Unit>
}
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,18 @@ internal class CustomerApiRepository @Inject constructor(
logger.error("Failed to update payment method $paymentMethodId.", it)
}

override suspend fun setDefaultPaymentMethod(
customerInfo: CustomerRepository.CustomerInfo,
paymentMethodId: String?
): Result<Customer> = stripeRepository.setDefaultPaymentMethod(
paymentMethodId = paymentMethodId,
customerId = customerInfo.id,
options = ApiRequest.Options(
apiKey = customerInfo.ephemeralKeySecret,
stripeAccount = lazyPaymentConfig.get().stripeAccountId,
)
)

private fun filterPaymentMethods(allPaymentMethods: List<PaymentMethod>): List<PaymentMethod> {
val paymentMethods = mutableListOf<PaymentMethod>()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,11 @@ internal interface CustomerRepository {
params: PaymentMethodUpdateParams
): Result<PaymentMethod>

suspend fun setDefaultPaymentMethod(
customerInfo: CustomerInfo,
paymentMethodId: String?,
): Result<Customer>

data class CustomerInfo(
val id: String,
val ephemeralKeySecret: String,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ class CustomerAdapterDataSourceTest {
),
)

val result = dataSource.setSavedSelection(SavedSelection.GooglePay)
val result = dataSource.setSavedSelection(SavedSelection.GooglePay, false)

assertThat(result).isInstanceOf<CustomerSheetDataResult.Success<Unit>>()
}
Expand All @@ -176,7 +176,7 @@ class CustomerAdapterDataSourceTest {
)
)

val result = dataSource.setSavedSelection(SavedSelection.GooglePay)
val result = dataSource.setSavedSelection(SavedSelection.GooglePay, false)

assertThat(result).isInstanceOf<CustomerSheetDataResult.Failure<Unit>>()

Expand All @@ -199,7 +199,7 @@ class CustomerAdapterDataSourceTest {
)
)

val result = dataSource.setSavedSelection(SavedSelection.GooglePay)
val result = dataSource.setSavedSelection(SavedSelection.GooglePay, false)

assertThat(result).isInstanceOf<CustomerSheetDataResult.Failure<Unit>>()

Expand Down
Loading
Loading