Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
234 changes: 234 additions & 0 deletions include/tvm/ffi/expected.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,234 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file tvm/ffi/expected.h
* \brief Runtime Expected container type for exception-free error handling.
*/
#ifndef TVM_FFI_EXPECTED_H_
#define TVM_FFI_EXPECTED_H_

#include <tvm/ffi/any.h>
#include <tvm/ffi/error.h>

#include <type_traits>
#include <utility>

namespace tvm {
namespace ffi {

/*!
* \brief Expected<T> provides exception-free error handling for FFI functions.
*
* Expected<T> is similar to Rust's Result<T, Error> or C++23's std::expected.
* It can hold either a success value of type T or an error of type Error.
*
* \tparam T The success type. Must be Any-compatible and cannot be Error.
*
* Usage:
* \code
* Expected<int> divide(int a, int b) {
* if (b == 0) {
* return ExpectedErr(Error("ValueError", "Division by zero"));
* }
* return ExpectedOk(a / b);
* }
*
* Expected<int> result = divide(10, 2);
* if (result.is_ok()) {
* int value = result.value();
* } else {
* Error err = result.error();
* }
* \endcode
*/
template <typename T>
class Expected {
public:
static_assert(!std::is_same_v<T, Error>, "Expected<Error> is not allowed. Use Error directly.");

/*!
* \brief Create an Expected with a success value.
* \param value The success value.
* \return Expected containing the success value.
*/
static Expected Ok(T value) { return Expected(Any(std::move(value))); }

/*!
* \brief Create an Expected with an error.
* \param error The error value.
* \return Expected containing the error.
*/
static Expected Err(Error error) { return Expected(Any(std::move(error))); }

/*!
* \brief Check if the Expected contains a success value.
* \return True if contains success value, false if contains error.
* \note Checks for Error first to handle cases where T is a base class of Error.
*/
TVM_FFI_INLINE bool is_ok() const { return !data_.as<Error>().has_value(); }

/*!
* \brief Check if the Expected contains an error.
* \return True if contains error, false if contains success value.
*/
TVM_FFI_INLINE bool is_err() const { return !is_ok(); }

/*!
* \brief Alias for is_ok().
* \return True if contains success value.
*/
TVM_FFI_INLINE bool has_value() const { return is_ok(); }

/*! \brief Access the success value. Throws the contained error if is_err(). */
TVM_FFI_INLINE T value() const& {
if (is_err()) throw data_.cast<Error>();
return data_.cast<T>();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For large data, we can use move instead of copy, so I agree with Gemini that we can add an overload function here:

TVM_FFI_INLINE T value() && {
  if (is_err()) { throw error(); }
  return std::move(data_).cast<T>();
}

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agree as well. I've added both const& and && qualified overloads for value():

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually we might not need both here, @tqchen would like to hear your opinion if we wanna keep one of them

}
/*! \brief Access the success value (rvalue). Throws the contained error if is_err(). */
TVM_FFI_INLINE T value() && {
if (is_err()) throw std::move(data_).template cast<Error>();
return std::move(data_).template cast<T>();
}

/*! \brief Access the error. Throws RuntimeError if is_ok(). */
TVM_FFI_INLINE Error error() const& {
if (!is_err()) TVM_FFI_THROW(RuntimeError) << "Bad expected access: contains value, not error";
return data_.cast<Error>();
}
/*! \brief Access the error (rvalue). Throws RuntimeError if is_ok(). */
TVM_FFI_INLINE Error error() && {
if (!is_err()) TVM_FFI_THROW(RuntimeError) << "Bad expected access: contains value, not error";
return std::move(data_).template cast<Error>();
}

/*!
* \brief Get the success value or a default value.
* \param default_value The value to return if Expected contains an error.
* \return The success value if present, otherwise the default value.
*/
template <typename U = std::remove_cv_t<T>>
TVM_FFI_INLINE T value_or(U&& default_value) const {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this function might be rarely used, because the philosophy of Expected<T> is handling error explicitly. Let's leave it for discussion here.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think maybe we should keep this since that is kind of standard practice std::optional and C++23 std::expected both have value_or.

ref std::expected (C++23) and std::optional (C++17)

if (is_ok()) {
return data_.cast<T>();
}
return T(std::forward<U>(default_value));
}

private:
friend struct TypeTraits<Expected<T>>;

/*!
* \brief Private constructor from Any.
* \param data The data containing either T or Error.
* \note This constructor is used by TypeTraits for conversion.
*/
explicit Expected(Any data) : data_(std::move(data)) {
TVM_FFI_ICHECK(data_.as<T>().has_value() || data_.as<Error>().has_value())
<< "Expected must contain either T or Error";
}

Any data_; // Holds either T or Error
};

/*!
* \brief Helper function to create Expected::Ok with type deduction.
* \tparam T The success type (deduced from argument).
* \param value The success value.
* \return Expected<T> containing the success value.
*/
template <typename T>
TVM_FFI_INLINE Expected<T> ExpectedOk(T value) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it is good to do round of API review, can you list all the APIs that are not conforming to std::expected, list their names, and discuss choice(i know some comes from rust API style, but good to be explicit). Would be good to list the APIs in the comment

return Expected<T>::Ok(std::move(value));
}

/*!
* \brief Helper function to create Expected::Err.
* \tparam T The success type (must be explicitly specified).
* \param error The error value.
* \return Expected<T> containing the error.
*/
template <typename T>
TVM_FFI_INLINE Expected<T> ExpectedErr(Error error) {
return Expected<T>::Err(std::move(error));
}

// TypeTraits specialization for Expected<T>
template <typename T>
inline constexpr bool use_default_type_traits_v<Expected<T>> = false;

template <typename T>
struct TypeTraits<Expected<T>> : public TypeTraitsBase {
TVM_FFI_INLINE static void CopyToAnyView(const Expected<T>& src, TVMFFIAny* result) {
if (src.is_err()) {
TypeTraits<Error>::CopyToAnyView(src.error(), result);
} else {
TypeTraits<T>::CopyToAnyView(src.value(), result);
}
}

TVM_FFI_INLINE static void MoveToAny(Expected<T> src, TVMFFIAny* result) {
if (src.is_err()) {
TypeTraits<Error>::MoveToAny(std::move(src).error(), result);
} else {
TypeTraits<T>::MoveToAny(std::move(src).value(), result);
}
}

TVM_FFI_INLINE static bool CheckAnyStrict(const TVMFFIAny* src) {
return TypeTraits<T>::CheckAnyStrict(src) || TypeTraits<Error>::CheckAnyStrict(src);
}

TVM_FFI_INLINE static Expected<T> CopyFromAnyViewAfterCheck(const TVMFFIAny* src) {
if (TypeTraits<T>::CheckAnyStrict(src)) {
return Expected<T>::Ok(TypeTraits<T>::CopyFromAnyViewAfterCheck(src));
}
return Expected<T>::Err(TypeTraits<Error>::CopyFromAnyViewAfterCheck(src));
}

TVM_FFI_INLINE static Expected<T> MoveFromAnyAfterCheck(TVMFFIAny* src) {
if (TypeTraits<T>::CheckAnyStrict(src)) {
return Expected<T>::Ok(TypeTraits<T>::MoveFromAnyAfterCheck(src));
}
return Expected<T>::Err(TypeTraits<Error>::MoveFromAnyAfterCheck(src));
}

TVM_FFI_INLINE static std::optional<Expected<T>> TryCastFromAnyView(const TVMFFIAny* src) {
if (auto opt = TypeTraits<T>::TryCastFromAnyView(src)) {
return Expected<T>::Ok(*std::move(opt));
}
if (auto opt_err = TypeTraits<Error>::TryCastFromAnyView(src)) {
return Expected<T>::Err(*std::move(opt_err));
}
return std::nullopt;
}

TVM_FFI_INLINE static std::string TypeStr() {
return "Expected<" + TypeTraits<T>::TypeStr() + ">";
}

TVM_FFI_INLINE static std::string TypeSchema() {
return R"({"type":"Expected","args":[)" + details::TypeSchema<T>::v() +
R"(,{"type":"ffi.Error"}]})";
}
};

} // namespace ffi
} // namespace tvm
#endif // TVM_FFI_EXPECTED_H_
58 changes: 58 additions & 0 deletions include/tvm/ffi/function.h
Original file line number Diff line number Diff line change
Expand Up @@ -643,6 +643,64 @@ class Function : public ObjectRef {
static_cast<FunctionObj*>(data_.get())->CallPacked(args.data(), args.size(), result);
}

/*!
* \brief Call the function and return Expected<T> for exception-free error handling.
* \tparam T The expected return type (default: Any).
* \param args The arguments to pass to the function.
* \return Expected<T> containing either the result or an error.
*
* This method provides exception-free calling by catching all exceptions
* and returning them as Error values in the Expected type.
*
* \code
* Function func = Function::GetGlobal("risky_function");
* Expected<int> result = func.CallExpected<int>(arg1, arg2);
* if (result.is_ok()) {
* int value = result.value();
* } else {
* Error err = result.error();
* }
* \endcode
*/
template <typename T = Any, typename... Args>
TVM_FFI_INLINE Expected<T> CallExpected(Args&&... args) const {
constexpr size_t kNumArgs = sizeof...(Args);
AnyView args_pack[kNumArgs > 0 ? kNumArgs : 1];
PackedArgs::Fill(args_pack, std::forward<Args>(args)...);

Any result;
FunctionObj* func_obj = static_cast<FunctionObj*>(data_.get());

// Use safe_call path to catch exceptions
int ret_code = func_obj->safe_call(func_obj, reinterpret_cast<const TVMFFIAny*>(args_pack),
kNumArgs, reinterpret_cast<TVMFFIAny*>(&result));

if (ret_code == 0) {
if constexpr (std::is_same_v<T, Any>) {
return Expected<T>::Ok(std::move(result));
} else {
// Check if result is Error (from Expected-returning function that returned Err)
if (result.template as<Error>().has_value()) {
return Expected<T>::Err(std::move(result).template cast<Error>());
}
// Try to extract as T
if (auto val = std::move(result).template as<T>()) {
return Expected<T>::Ok(std::move(*val));
}
return Expected<T>::Err(
Error("TypeError",
"CallExpected: result type mismatch, expected " + TypeTraits<T>::TypeStr(), ""));
}
} else if (ret_code == -2) {
// Environment error already set (e.g., Python KeyboardInterrupt)
// We still throw this since it's a signal, not a normal error
throw ::tvm::ffi::EnvErrorAlreadySet();
} else {
// Error occurred - retrieve from safe call context and return Err
return Expected<T>::Err(details::MoveFromSafeCallRaised());
}
}

/*! \return Whether the packed function is nullptr */
TVM_FFI_INLINE bool operator==(std::nullptr_t) const { return data_ == nullptr; }
/*! \return Whether the packed function is not nullptr */
Expand Down
21 changes: 21 additions & 0 deletions include/tvm/ffi/function_details.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,11 @@

namespace tvm {
namespace ffi {

// Forward declaration for Expected<T>
template <typename T>
class Expected;

namespace details {

template <typename ArgType>
Expand Down Expand Up @@ -67,6 +72,19 @@ static constexpr bool ArgSupported =
std::is_same_v<std::remove_const_t<std::remove_reference_t<T>>, AnyView> ||
TypeTraitsNoCR<T>::convert_enabled));

template <typename T>
struct is_expected : std::false_type {
using value_type = void;
};

template <typename T>
struct is_expected<Expected<T>> : std::true_type {
using value_type = T;
};

template <typename T>
inline constexpr bool is_expected_v = is_expected<T>::value;

// NOTE: return type can only support non-reference managed returns
template <typename T>
static constexpr bool RetSupported =
Expand Down Expand Up @@ -219,6 +237,9 @@ TVM_FFI_INLINE void unpack_call(std::index_sequence<Is...>, const std::string* o
// use index sequence to do recursive-less unpacking
if constexpr (std::is_same_v<R, void>) {
f(ArgValueWithContext<std::tuple_element_t<Is, PackedArgs>>{args, Is, optional_name, f_sig}...);
} else if constexpr (is_expected_v<R>) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This behavior can be a bit confusing. If the ffi.Function is explicitly returning Expected Value (instead of throw using the error handling mechanism, then the function should successully return instead of implicitly throw when error is found?

Mayb need a regression testcase for this

Copy link
Contributor

@Kathryn-cat Kathryn-cat Jan 26, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i think the current design unwraps Expected<T> at the FFI boundary. CallExpected use safe_call to catch the error and return Expected<T>; on the contrary, the normal call path would throw an error. This might seem a bit convoluted. Do you have better suggested ideas?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are two ways that Expected get returned

  • W0: Expected get "raised" internally by setting the TLS value via TVMFFISetRaised
  • W1: Expected get returned as a normal return value

And there are two ways to call a function now

  • C0: Normal call
  • C1: CallExpected

Would be good to discuss the overall relation in the mix of four cases

  • C0+ W0: we need to throw the error to caller
  • C0 + W1: no error should be thrown, the Expected should be returned to the caller
  • C1 + W0: we should return Expected with the error "raised" set to the returned Expected value
  • C1 + W1: we need to return the Expected, note that in this case it is harder to distinguish error being returned versus error being raised if the desirable logic is to return an error, but this is more of a rare case.

In any case, it would be good first to make sure behavior of C0 is correct, and the particular context seems to suggest C0+W1, in such case, we should return the value to the caller.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the detailed breakdown of the four scenarios (C0/C1 + W0/W1) which really helped clarify the design considerations. Based on my understanding, the key question is about the C0 + W1 scenario (normal Call() on a function returning Expected). Here's how I see the two possible directions:

Option A: Keep current behavior (auto-unwrap)

  // Current implementation                                                                                                                                                      
  if (expected_result.is_ok()) {                                                                                                                                                 
      *rv = std::move(expected_result).value();  // unwrap value                                                                                                                 
  } else {                                                                                                                                                                       
      throw std::move(expected_result).error();  // throw error                                                                                                                  
  }  
  • Pros: Caller uses Call() and gets the value directly
  • Cons: Cannot distinguish "returned error" from "raised error"

Option B: No unwrap, return Expected directly

// Proposed change                                                                                                                                                             
*rv = f(...);  // return Expected as-is                                                                                                                                        
  • Pros: Clear distinction between returned vs raised errors
  • Cons: Caller must use Call<Expected>() and handle it explicitly

Personally, I'd slightly lean toward Option B after @tqchen's breakdown. If a function explicitly chooses to return Expected, I think we should respect that and let the caller receive it directly.

I'd like to confirm your preference before updating the implementation. Happy to add regression tests or update implementation for whichever approach we go with.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

agree option B is better

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Updated with option B. Please take another look, thanks!

*rv = f(ArgValueWithContext<std::tuple_element_t<Is, PackedArgs>>{args, Is, optional_name,
f_sig}...);
} else {
*rv = R(f(ArgValueWithContext<std::tuple_element_t<Is, PackedArgs>>{args, Is, optional_name,
f_sig}...));
Expand Down
1 change: 1 addition & 0 deletions include/tvm/ffi/tvm_ffi.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
#include <tvm/ffi/dtype.h>
#include <tvm/ffi/endian.h>
#include <tvm/ffi/error.h>
#include <tvm/ffi/expected.h>
#include <tvm/ffi/function.h>
#include <tvm/ffi/function_details.h>
#include <tvm/ffi/memory.h>
Expand Down
Loading