-
Notifications
You must be signed in to change notification settings - Fork 143
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Python] Qvector init from state #1713
[Python] Qvector init from state #1713
Conversation
…A/cuda-quantum into qvector-init-from-state2
…A/cuda-quantum into qvector-init-from-state2
…A/cuda-quantum into qvector-init-from-state2
…A/cuda-quantum into qvector-init-from-state2
…A/cuda-quantum into qvector-init-from-state2
…A/cuda-quantum into qvector-init-from-state2
…A/cuda-quantum into qvector-init-from-state-update
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM 👍
b5ce9a7
into
NVIDIA:experimental/stateHandling
auto *ctx = eleTy.getContext(); | ||
auto eTy = cudaq::opt::factory::getCharType(ctx); | ||
return cudaq::opt::factory::getPointerType(eTy); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can be simplified to just:
return cudaq::opt::factory::getPointerType(eleTy.getContext());
On the other hand, we don't use the eleTy
, so we could erase this function and just use getPointerType
at the call site.
if (isa<NoneType>(eleTy)) | ||
return factory::getPointerType(type.getContext()); | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why doesn't the recursion handle this? The recursive call on line 2035 ought to handle pointers to pointers.
@@ -434,8 +434,10 @@ class GenerateKernelExecution | |||
hasTrailingData = true; | |||
continue; | |||
} | |||
if (isa<cudaq::cc::PointerType>(currEleTy)) | |||
if (isa<cudaq::cc::PointerType>(currEleTy) && | |||
!isStatePointerType(currEleTy)) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: don't need braces.
@@ -933,6 +935,13 @@ class GenerateKernelExecution | |||
builder.create<cudaq::cc::StoreOp>(loc, endPtr, sret2); | |||
} | |||
|
|||
static bool isStatePointerType(mlir::Type ty) { | |||
if (auto ptrTy = dyn_cast<cudaq::cc::PointerType>(ty)) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: no braces
return isa<cudaq::cc::PointerType>(i.getType()) && | ||
!isStatePointerType(i.getType()); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't get this one. We're dropping hidden arguments. Why would a cudaq::state*
appear as a this
or sret
? Do we have an example?
@@ -1207,8 +1217,9 @@ class GenerateKernelExecution | |||
hasTrailingData = true; | |||
continue; | |||
} | |||
if (isa<cudaq::cc::PointerType>(inTy)) | |||
if (isa<cudaq::cc::PointerType>(inTy) && !isStatePointerType(inTy)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This doesn't look correct. We cannot pass a cudaq::state*
as a pointer-free value.
Description
Add initialization of
qvector
fromcudaq::State
Details
The changes that allow using
cudaq.State
in python kernel functions:Add
cc.StateType
to python bindingsAllow arguments of StateType to be passed to kernel functions
Made
cudaq.from_state
fail on incorrect precision dataUpdate
ast_bridge.py
to create MLIR for initializingqvector
from stateSupport capturing states from python non-kernel code
qalloc
inkernel_builder.py
Convert StateType to
ptr<i8>
incudaq::opt::initializeTypeConversions
Update
QmemRAIIOpRewrite
to call a newextern C
API on state pointers__quantum__rt__qubit_allocate_array_with_cudaq_state_ptr
Implement
__quantum__rt__qubit_allocate_array_with_cudaq_state_ptr
state_helper
to access private data)Update related intrinsics
__nvqpp_cudaq_state_numberOfQubits
__nvqpp_cudaq_state_vectorData
(to avoid a copy in MLIR)Add tests
cudaq.from_state
fails on incorrect precision dataSee python usage examples in [RFC] [Language] Quantum allocation with state initialization #1086
Towards: